diff --git a/Makefile b/Makefile index ef0f74c959..0d19ae3f6f 100644 --- a/Makefile +++ b/Makefile @@ -13,15 +13,11 @@ VERSION := ${AUTV}${VERSION_SUFFIX} VERSION_MM := ${AUTVMINMAJ}${VERSION_SUFFIX} -# dbt runner version info -DBT_AUTV=$(shell python3 -c "from dlt.dbt_runner._version import __version__;print(__version__)") -DBT_AUTVMINMAJ=$(shell python3 -c "from dlt.dbt_runner._version import __version__;print('.'.join(__version__.split('.')[:-1]))") - DBT_NAME := scalevector/dlt-dbt-runner DBT_IMG := ${DBT_NAME}:${TAG} DBT_LATEST := ${DBT_NAME}:latest${VERSION_SUFFIX} -DBT_VERSION := ${DBT_AUTV}${VERSION_SUFFIX} -DBT_VERSION_MM := ${DBT_AUTVMINMAJ}${VERSION_SUFFIX} +DBT_VERSION := ${AUTV}${VERSION_SUFFIX} +DBT_VERSION_MM := ${AUTVMINMAJ}${VERSION_SUFFIX} install-poetry: ifneq ($(VIRTUAL_ENV),) @@ -38,7 +34,7 @@ dev: has-poetry lint: ./check-package.sh - poetry run mypy --config-file mypy.ini dlt examples + poetry run mypy --config-file mypy.ini dlt poetry run flake8 --max-line-length=200 examples dlt poetry run flake8 --max-line-length=200 tests # dlt/pipeline dlt/common/schema dlt/common/normalizers @@ -50,7 +46,7 @@ lint-security: reset-test-storage: -rm -r _storage mkdir _storage - python3 test/tools/create_storages.py + python3 tests/tools/create_storages.py recreate-compiled-deps: poetry export -f requirements.txt --output _gen_requirements.txt --without-hashes --extras gcp --extras redshift diff --git a/dlt/__init__.py b/dlt/__init__.py index b53f5a0b6b..3dc1f76bc6 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -1 +1 @@ -from dlt._version import common_version as __version__ \ No newline at end of file +__version__ = "0.1.0" diff --git a/dlt/_version.py b/dlt/_version.py deleted file mode 100644 index ddd0d93607..0000000000 --- a/dlt/_version.py +++ /dev/null @@ -1,3 +0,0 @@ -common_version = "0.1.0" -loader_version = "0.1.0" -normalize_version = "0.1.0" diff --git a/dlt/cli/dlt.py b/dlt/cli/dlt.py index 33759f802e..07dc6ca729 100644 --- a/dlt/cli/dlt.py +++ b/dlt/cli/dlt.py @@ -7,9 +7,8 @@ from dlt.cli import TRunnerArgs from dlt.common.schema import Schema from dlt.common.typing import DictStrAny -from dlt.common.utils import str2bool -from dlt.pipeline import Pipeline, PostgresPipelineCredentials +from dlt.pipeline import pipeline, restore def add_pool_cli_arguments(parser: argparse.ArgumentParser) -> None: @@ -27,33 +26,35 @@ def add_pool_cli_arguments(parser: argparse.ArgumentParser) -> None: def main() -> None: parser = argparse.ArgumentParser(description="Runs various DLT modules", formatter_class=argparse.ArgumentDefaultsHelpFormatter) subparsers = parser.add_subparsers(dest="command") - normalize = subparsers.add_parser("normalize", help="Runs normalize") - add_pool_cli_arguments(normalize) - load = subparsers.add_parser("load", help="Runs loader") - add_pool_cli_arguments(load) + + # normalize = subparsers.add_parser("normalize", help="Runs normalize") + # add_pool_cli_arguments(normalize) + # load = subparsers.add_parser("load", help="Runs loader") + # add_pool_cli_arguments(load) + dbt = subparsers.add_parser("dbt", help="Executes dbt package") add_pool_cli_arguments(dbt) schema = subparsers.add_parser("schema", help="Shows, converts and upgrades schemas") schema.add_argument("file", help="Schema file name, in yaml or json format, will autodetect based on extension") schema.add_argument("--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format") schema.add_argument("--remove-defaults", action="store_true", help="Does not show 
default hint values") - pipeline = subparsers.add_parser("pipeline", help="Operations on the pipelines") - pipeline.add_argument("name", help="Pipeline name") - pipeline.add_argument("workdir", help="Pipeline working directory") - pipeline.add_argument("operation", choices=["failed_loads"], default="failed_loads", help="Show failed loads for a pipeline") + pipe_cmd = subparsers.add_parser("pipeline", help="Operations on the pipelines") + pipe_cmd.add_argument("name", help="Pipeline name") + pipe_cmd.add_argument("operation", choices=["failed_loads", "drop"], default="failed_loads", help="Show failed loads for a pipeline") + pipe_cmd.add_argument("--workdir", help="Pipeline working directory", default=None) # TODO: consider using fire: https://github.com/google/python-fire # TODO: this also looks better https://click.palletsprojects.com/en/8.1.x/complex/#complex-guide args = parser.parse_args() run_f: Callable[[TRunnerArgs], None] = None - if args.command == "normalize": - from dlt.normalize.normalize import run_main as normalize_run - run_f = normalize_run - elif args.command == "load": - from dlt.load.load import run_main as loader_run - run_f = loader_run - elif args.command == "dbt": + # if args.command == "normalize": + # from dlt.normalize.normalize import run_main as normalize_run + # run_f = normalize_run + # elif args.command == "load": + # from dlt.load.load import run_main as loader_run + # run_f = loader_run + if args.command == "dbt": from dlt.dbt_runner.runner import run_main as dbt_run run_f = dbt_run elif args.command == "schema": @@ -70,13 +71,21 @@ def main() -> None: print(schema_str) exit(0) elif args.command == "pipeline": - p = Pipeline(args.name) - p.restore_pipeline(PostgresPipelineCredentials("dummy"), args.workdir) - completed_loads = p.list_completed_loads() - for load_id in completed_loads: - print(f"Checking failed jobs in {load_id}") - for job, failed_message in p.list_failed_jobs(load_id): - print(f"JOB: {job}\nMSG: {failed_message}") + # from dlt.load import dummy + + p = restore(pipeline_name=args.name, working_dir=args.workdir) + print(f"Found pipeline {p.pipeline_name} ({args.name}) in {p.working_dir} ({args.workdir}) with state {p._get_state()}") + + if args.operation == "failed_loads": + completed_loads = p.list_completed_load_packages() + for load_id in completed_loads: + print(f"Checking failed jobs in load id '{load_id}'") + for job, failed_message in p.list_failed_jobs_in_package(load_id): + print(f"JOB: {os.path.abspath(job)}\nMSG: {failed_message}") + + if args.operation == "drop": + p.drop() + exit(0) else: parser.print_help() diff --git a/dlt/common/__init__.py b/dlt/common/__init__.py index 6da3ee3a0e..7a72b56a9b 100644 --- a/dlt/common/__init__.py +++ b/dlt/common/__init__.py @@ -3,4 +3,3 @@ from .pendulum import pendulum # noqa: F401 from .json import json # noqa: F401, I251 from .time import sleep # noqa: F401 -from dlt._version import common_version as __version__ diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 0be19399ce..8a2d92346b 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,11 +1,6 @@ -from .run_configuration import RunConfiguration, BaseConfiguration, CredentialsConfiguration # noqa: F401 -from .normalize_volume_configuration import NormalizeVolumeConfiguration, ProductionNormalizeVolumeConfiguration # noqa: F401 -from .load_volume_configuration import LoadVolumeConfiguration, ProductionLoadVolumeConfiguration # noqa: F401 -from 
.schema_volume_configuration import SchemaVolumeConfiguration, ProductionSchemaVolumeConfiguration # noqa: F401 -from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 -from .gcp_client_credentials import GcpClientCredentials # noqa: F401 -from .postgres_credentials import PostgresCredentials # noqa: F401 -from .utils import make_configuration # noqa: F401 +from .specs.base_configuration import configspec, is_valid_hint # noqa: F401 +from .resolve import resolve_configuration, inject_namespace # noqa: F401 +from .inject import with_config, last_config, get_fun_spec from .exceptions import ( # noqa: F401 - ConfigEntryMissingException, ConfigEnvValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) + ConfigFieldMissingException, ConfigValueCannotBeCoercedException, ConfigIntegrityException, ConfigFileNotFoundException) diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py new file mode 100644 index 0000000000..fff80d79ed --- /dev/null +++ b/dlt/common/configuration/container.py @@ -0,0 +1,66 @@ +from contextlib import contextmanager +from typing import Dict, Iterator, Type, TypeVar + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext +from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, ContextDefaultCannotBeCreated + +TConfiguration = TypeVar("TConfiguration", bound=ContainerInjectableContext) + + +class Container: + + _INSTANCE: "Container" = None + + contexts: Dict[Type[ContainerInjectableContext], ContainerInjectableContext] + + def __new__(cls: Type["Container"]) -> "Container": + if not cls._INSTANCE: + cls._INSTANCE = super().__new__(cls) + cls._INSTANCE.contexts = {} + return cls._INSTANCE + + def __init__(self) -> None: + pass + + def __getitem__(self, spec: Type[TConfiguration]) -> TConfiguration: + # return existing config object or create it from spec + if not issubclass(spec, ContainerInjectableContext): + raise KeyError(f"{spec.__name__} is not a context") + + item = self.contexts.get(spec) + if item is None: + if spec.can_create_default: + item = spec() + self.contexts[spec] = item + else: + raise ContextDefaultCannotBeCreated(spec) + + return item # type: ignore + + def __setitem__(self, spec: Type[TConfiguration], value: TConfiguration) -> None: + self.contexts[spec] = value + + def __contains__(self, spec: Type[TConfiguration]) -> bool: + return spec in self.contexts + + @contextmanager + def injectable_context(self, config: TConfiguration) -> Iterator[TConfiguration]: + spec = type(config) + previous_config: ContainerInjectableContext = None + if spec in self.contexts: + previous_config = self.contexts[spec] + # set new config and yield context + try: + self.contexts[spec] = config + yield config + finally: + # before setting the previous config for given spec, check if there was no overlapping modification + if self.contexts[spec] is config: + # config is injected for spec so restore previous + if previous_config is None: + del self.contexts[spec] + else: + self.contexts[spec] = previous_config + else: + # value was modified in the meantime and not restored + raise ContainerInjectableContextMangled(spec, self.contexts[spec], config) diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index eb859d1b30..c905820e06 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -1,34 +1,59 @@ -from typing import Iterable, Union +from 
typing import Any, Mapping, Type, Union, NamedTuple, Sequence from dlt.common.exceptions import DltException +class LookupTrace(NamedTuple): + provider: str + namespaces: Sequence[str] + key: str + value: Any + + class ConfigurationException(DltException): def __init__(self, msg: str) -> None: super().__init__(msg) -class ConfigEntryMissingException(ConfigurationException): - """thrown when not all required config elements are present""" - def __init__(self, missing_set: Iterable[str], namespace: str = None) -> None: - self.missing_set = missing_set - self.namespace = namespace +class ContainerException(ConfigurationException): + """base exception for all exceptions related to injectable container""" + pass + + +class ConfigProviderException(ConfigurationException): + """base exceptions for all exceptions raised by config providers""" + pass + + +class ConfigurationWrongTypeException(ConfigurationException): + def __init__(self, _typ: type) -> None: + super().__init__(f"Invalid configuration instance type {_typ}. Configuration instances must derive from BaseConfiguration.") + - msg = 'Missing config keys: ' + str(missing_set) - if namespace: - msg += ". Note that required namespace for that keys is " + namespace + " and namespace separator is two underscores" +class ConfigFieldMissingException(ConfigurationException): + """thrown when not all required config fields are present""" + + def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) -> None: + self.traces = traces + self.spec_name = spec_name + + msg = f"Following fields are missing: {str(list(traces.keys()))} in configuration with spec {spec_name}\n" + for f, field_traces in traces.items(): + msg += f'\tfor field "{f}" config providers and keys were tried in following order:\n' + for tr in field_traces: + msg += f'\t\tIn {tr.provider} key {tr.key} was not found.\n' super().__init__(msg) -class ConfigEnvValueCannotBeCoercedException(ConfigurationException): - """thrown when value from ENV cannot be coerced to hinted type""" +class ConfigValueCannotBeCoercedException(ConfigurationException): + """thrown when value returned by config provider cannot be coerced to hinted type""" - def __init__(self, attr_name: str, env_value: str, hint: type) -> None: - self.attr_name = attr_name - self.env_value = env_value + def __init__(self, field_name: str, field_value: Any, hint: type) -> None: + self.field_name = field_name + self.field_value = field_value self.hint = hint - super().__init__('env value %s cannot be coerced into type %s in attr %s' % (env_value, str(hint), attr_name)) + super().__init__('env value %s cannot be coerced into type %s in attr %s' % (field_value, str(hint), field_name)) class ConfigIntegrityException(ConfigurationException): @@ -46,3 +71,55 @@ class ConfigFileNotFoundException(ConfigurationException): def __init__(self, path: str) -> None: super().__init__(f"Missing config file in {path}") + + +class ConfigFieldMissingTypeHintException(ConfigurationException): + """thrown when configuration specification does not have type hint""" + + def __init__(self, field_name: str, spec: Type[Any]) -> None: + self.field_name = field_name + self.typ_ = spec + super().__init__(f"Field {field_name} on configspec {spec} does not provide required type hint") + + +class ConfigFieldTypeHintNotSupported(ConfigurationException): + """thrown when configuration specification uses not supported type in hint""" + + def __init__(self, field_name: str, spec: Type[Any], typ_: Type[Any]) -> None: + self.field_name = 
field_name + self.typ_ = spec + super().__init__(f"Field {field_name} on configspec {spec} has hint with unsupported type {typ_}") + + +class ValueNotSecretException(ConfigurationException): + def __init__(self, provider_name: str, key: str) -> None: + self.provider_name = provider_name + self.key = key + super().__init__(f"Provider {provider_name} cannot hold secret values but key {key} with secret value is present") + + +class InvalidInitialValue(ConfigurationException): + def __init__(self, spec: Type[Any], initial_value_type: Type[Any]) -> None: + self.spec = spec + self.initial_value_type = initial_value_type + super().__init__(f"Initial value of type {initial_value_type} is not valid for {spec.__name__}") + + +class ContainerInjectableContextMangled(ContainerException): + def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) -> None: + self.spec = spec + self.existing_config = existing_config + self.expected_config = expected_config + super().__init__(f"When restoring context {spec.__name__}, instance {expected_config} was expected, instead instance {existing_config} was found.") + + +class ContextDefaultCannotBeCreated(ContainerException, KeyError): + def __init__(self, spec: Type[Any]) -> None: + self.spec = spec + super().__init__(f"Container cannot create the default value of context {spec.__name__}.") + + +class DuplicateConfigProviderException(ConfigProviderException): + def __init__(self, provider_name: str) -> None: + self.provider_name = provider_name + super().__init__(f"Provider with name {provider_name} already present in ConfigProvidersContext") diff --git a/dlt/common/configuration/gcp_client_credentials.py b/dlt/common/configuration/gcp_client_credentials.py deleted file mode 100644 index 5388eb6e9f..0000000000 --- a/dlt/common/configuration/gcp_client_credentials.py +++ /dev/null @@ -1,33 +0,0 @@ -from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import CredentialsConfiguration - - -class GcpClientCredentials(CredentialsConfiguration): - - __namespace__: str = "GCP" - - PROJECT_ID: str = None - CRED_TYPE: str = "service_account" - PRIVATE_KEY: TSecretValue = None - LOCATION: str = "US" - TOKEN_URI: str = "https://oauth2.googleapis.com/token" - CLIENT_EMAIL: str = None - - HTTP_TIMEOUT: float = 15.0 - RETRY_DEADLINE: float = 600 - - @classmethod - def check_integrity(cls) -> None: - if cls.PRIVATE_KEY and cls.PRIVATE_KEY[-1] != "\n": - # must end with new line, otherwise won't be parsed by Crypto - cls.PRIVATE_KEY = TSecretValue(cls.PRIVATE_KEY + "\n") - - @classmethod - def as_credentials(cls) -> StrAny: - return { - "type": cls.CRED_TYPE, - "project_id": cls.PROJECT_ID, - "private_key": cls.PRIVATE_KEY, - "token_uri": cls.TOKEN_URI, - "client_email": cls.CLIENT_EMAIL - } \ No newline at end of file diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py new file mode 100644 index 0000000000..092f2b234c --- /dev/null +++ b/dlt/common/configuration/inject.py @@ -0,0 +1,171 @@ +import re +import inspect +from makefun import wraps +from types import ModuleType +from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload +from inspect import Signature, Parameter + +from dlt.common.typing import StrAny, TFun, AnyFun +from dlt.common.configuration.resolve import resolve_configuration, inject_namespace +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, is_valid_hint, configspec +from dlt.common.configuration.specs.config_namespace_context 
import ConfigNamespacesContext + +# [^.^_]+ splits by . or _ +_SLEEPING_CAT_SPLIT = re.compile("[^.^_]+") +_LAST_DLT_CONFIG = "_last_dlt_config" +TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) +# keep a registry of all the decorated functions +_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {} + + +def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]: + return _FUNC_SPECS.get(id(f)) + + +@overload +def with_config(func: TFun, /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> TFun: + ... + + +@overload +def with_config(func: None = ..., /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: + ... + + +def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, auto_namespace: bool = False, only_kw: bool = False, namespaces: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]: + + namespace_f: Callable[[StrAny], str] = None + # namespace may be a function from function arguments to namespace + if callable(namespaces): + namespace_f = namespaces + + def decorator(f: TFun) -> TFun: + SPEC: Type[BaseConfiguration] = None + sig: Signature = inspect.signature(f) + kwargs_arg = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None) + spec_arg: Parameter = None + pipeline_name_arg: Parameter = None + namespace_context = ConfigNamespacesContext() + + if spec is None: + SPEC = _spec_from_signature(_get_spec_name_from_f(f), inspect.getmodule(f), sig, only_kw) + else: + SPEC = spec + + for p in sig.parameters.values(): + # for all positional parameters that do not have default value, set default + if hasattr(SPEC, p.name) and p.default == Parameter.empty: + p._default = None # type: ignore + if p.annotation is SPEC: + # if any argument has type SPEC then us it to take initial value + spec_arg = p + if p.name == "pipeline_name" and auto_namespace: + # if argument has name pipeline_name and auto_namespace is used, use it to generate namespace context + pipeline_name_arg = p + + + @wraps(f, new_sig=sig) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # bind parameters to signature + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + # for calls containing resolved spec in the kwargs, we do not need to resolve again + config: BaseConfiguration = None + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + else: + # if namespace derivation function was provided then call it + nonlocal namespaces + if namespace_f: + namespaces = (namespace_f(bound_args.arguments), ) + # namespaces may be a string + if isinstance(namespaces, str): + namespaces = (namespaces,) + # if one of arguments is spec the use it as initial value + if spec_arg: + config = bound_args.arguments.get(spec_arg.name, None) + # resolve SPEC, also provide namespace_context with pipeline_name + if pipeline_name_arg: + namespace_context.pipeline_name = bound_args.arguments.get(pipeline_name_arg.name, None) + with inject_namespace(namespace_context): + config = resolve_configuration(config or SPEC(), namespaces=namespaces, initial_value=bound_args.arguments) + resolved_params = dict(config) + # overwrite or add resolved params + for p in sig.parameters.values(): + if p.name in resolved_params: + bound_args.arguments[p.name] = resolved_params.pop(p.name) + if p.annotation is SPEC: + bound_args.arguments[p.name] = config + # pass all other config 
parameters into kwargs if present + if kwargs_arg is not None: + bound_args.arguments[kwargs_arg.name].update(resolved_params) + bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config + # call the function with resolved config + return f(*bound_args.args, **bound_args.kwargs) + + # register the spec for a wrapped function + _FUNC_SPECS[id(_wrap)] = SPEC + + return _wrap # type: ignore + + # See if we're being called as @with_config or @with_config(). + if func is None: + # We're called with parens. + return decorator + + if not callable(func): + raise ValueError("First parameter to the with_config must be callable ie. by using it as function decorator") + + # We're called as @with_config without parens. + return decorator(func) + + +def last_config(**kwargs: Any) -> BaseConfiguration: + return kwargs[_LAST_DLT_CONFIG] # type: ignore + + +def _get_spec_name_from_f(f: AnyFun) -> str: + func_name = f.__qualname__.replace(".", "") # func qual name contains position in the module, separated by dots + + def _first_up(s: str) -> str: + return s[0].upper() + s[1:] + + return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration" + + +def _spec_from_signature(name: str, module: ModuleType, sig: Signature, kw_only: bool = False) -> Type[BaseConfiguration]: + # synthesize configuration from the signature + fields: Dict[str, Any] = {} + annotations: Dict[str, Any] = {} + + for p in sig.parameters.values(): + # skip *args and **kwargs, skip typical method params and if kw_only flag is set: accept KEYWORD ONLY args + if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"] and \ + (kw_only and p.kind == Parameter.KEYWORD_ONLY or not kw_only): + field_type = Any if p.annotation == Parameter.empty else p.annotation + if is_valid_hint(field_type): + field_default = None if p.default == Parameter.empty else p.default + # try to get type from default + if field_type is Any and field_default: + field_type = type(field_default) + # make type optional if explicit None is provided as default + if p.default is None: + field_type = Optional[field_type] + # set annotations + annotations[p.name] = field_type + # set field with default value + + fields[p.name] = field_default + # new type goes to the module where sig was declared + fields["__module__"] = module.__name__ + # set annotations so they are present in __dict__ + fields["__annotations__"] = annotations + # synthesize type + T: Type[BaseConfiguration] = type(name, (BaseConfiguration,), fields) + # add to the module + setattr(module, name, T) + SPEC = configspec(init=False)(T) + # print(f"SYNTHESIZED {SPEC} in {inspect.getmodule(SPEC)} for sig {sig}") + # import dataclasses + # print("\n".join(map(str, dataclasses.fields(SPEC)))) + return SPEC diff --git a/dlt/common/configuration/load_volume_configuration.py b/dlt/common/configuration/load_volume_configuration.py deleted file mode 100644 index 41e1746769..0000000000 --- a/dlt/common/configuration/load_volume_configuration.py +++ /dev/null @@ -1,11 +0,0 @@ -import os - -from dlt.common.configuration.run_configuration import BaseConfiguration - - -class LoadVolumeConfiguration(BaseConfiguration): - LOAD_VOLUME_PATH: str = os.path.join("_storage", "load") # path to volume where files to be loaded to analytical storage are stored - DELETE_COMPLETED_JOBS: bool = False # if set to true the folder with completed jobs will be deleted - -class ProductionLoadVolumeConfiguration(LoadVolumeConfiguration): - LOAD_VOLUME_PATH: str = None diff 
--git a/dlt/common/configuration/normalize_volume_configuration.py b/dlt/common/configuration/normalize_volume_configuration.py deleted file mode 100644 index 3e3b8c34d6..0000000000 --- a/dlt/common/configuration/normalize_volume_configuration.py +++ /dev/null @@ -1,11 +0,0 @@ -import os - -from dlt.common.configuration import BaseConfiguration - - -class NormalizeVolumeConfiguration(BaseConfiguration): - NORMALIZE_VOLUME_PATH: str = os.path.join("_storage", "normalize") # path to volume where normalized loader files will be stored - - -class ProductionNormalizeVolumeConfiguration(NormalizeVolumeConfiguration): - NORMALIZE_VOLUME_PATH: str = None diff --git a/dlt/common/configuration/pool_runner_configuration.py b/dlt/common/configuration/pool_runner_configuration.py deleted file mode 100644 index 9e900c3fca..0000000000 --- a/dlt/common/configuration/pool_runner_configuration.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Literal, Optional -from dlt.common.configuration import RunConfiguration - -TPoolType = Literal["process", "thread", "none"] - - -class PoolRunnerConfiguration(RunConfiguration): - POOL_TYPE: TPoolType = None # type of pool to run, must be set in derived configs - WORKERS: Optional[int] = None # how many threads/processes in the pool - RUN_SLEEP: float = 0.5 # how long to sleep between runs with workload, seconds - RUN_SLEEP_IDLE: float = 1.0 # how long to sleep when no more items are pending, seconds - RUN_SLEEP_WHEN_FAILED: float = 1.0 # how long to sleep between the runs when failed - IS_SINGLE_RUN: bool = False # should run only once until all pending data is processed, and exit - WAIT_RUNS: int = 0 # how many runs to wait for first data coming in is IS_SINGLE_RUN is set - EXIT_ON_EXCEPTION: bool = False # should exit on exception - STOP_AFTER_RUNS: int = 10000 # will stop runner with exit code -2 after so many runs, that prevents memory fragmentation diff --git a/dlt/common/configuration/postgres_credentials.py b/dlt/common/configuration/postgres_credentials.py deleted file mode 100644 index 4b090b6a65..0000000000 --- a/dlt/common/configuration/postgres_credentials.py +++ /dev/null @@ -1,24 +0,0 @@ -from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import CredentialsConfiguration - - -class PostgresCredentials(CredentialsConfiguration): - - __namespace__: str = "PG" - - DBNAME: str = None - PASSWORD: TSecretValue = None - USER: str = None - HOST: str = None - PORT: int = 5439 - CONNECT_TIMEOUT: int = 15 - - @classmethod - def check_integrity(cls) -> None: - cls.DBNAME = cls.DBNAME.lower() - # cls.DEFAULT_DATASET = cls.DEFAULT_DATASET.lower() - cls.PASSWORD = TSecretValue(cls.PASSWORD.strip()) - - @classmethod - def as_credentials(cls) -> StrAny: - return cls.as_dict() diff --git a/dlt/common/configuration/providers/__init__.py b/dlt/common/configuration/providers/__init__.py index e69de29bb2..42488f1b96 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -0,0 +1,3 @@ +from .provider import Provider +from .environ import EnvironProvider +from .dictionary import DictionaryProvider \ No newline at end of file diff --git a/dlt/common/configuration/providers/container.py b/dlt/common/configuration/providers/container.py new file mode 100644 index 0000000000..cd1f1a7049 --- /dev/null +++ b/dlt/common/configuration/providers/container.py @@ -0,0 +1,38 @@ +import contextlib +from typing import Any, ClassVar, Optional, Type, Tuple + +from dlt.common.configuration.container 
import Container +from dlt.common.configuration.specs import ContainerInjectableContext + +from .provider import Provider + + +class ContextProvider(Provider): + + NAME: ClassVar[str] = "Injectable Context" + + def __init__(self) -> None: + self.container = Container() + + @property + def name(self) -> str: + return ContextProvider.NAME + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + assert namespaces == () + + # only context is a valid hint + with contextlib.suppress(TypeError): + if issubclass(hint, ContainerInjectableContext): + # contexts without defaults will raise ContextDefaultCannotBeCreated + return self.container[hint], hint.__name__ + + return None, str(hint) + + @property + def supports_secrets(self) -> bool: + return True + + @property + def supports_namespaces(self) -> bool: + return False diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py new file mode 100644 index 0000000000..252a2fb216 --- /dev/null +++ b/dlt/common/configuration/providers/dictionary.py @@ -0,0 +1,46 @@ +from contextlib import contextmanager +from typing import Any, ClassVar, Iterator, Optional, Type, Tuple + +from dlt.common.typing import StrAny + +from .provider import Provider + + +class DictionaryProvider(Provider): + + NAME: ClassVar[str] = "Dictionary Provider" + + def __init__(self) -> None: + self._values: StrAny = {} + pass + + @property + def name(self) -> str: + return self.NAME + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + full_path = namespaces + (key,) + full_key = "__".join(full_path) + node = self._values + try: + for k in full_path: + node = node[k] + return node, full_key + except KeyError: + return None, full_key + + @property + def supports_secrets(self) -> bool: + return True + + @property + def supports_namespaces(self) -> bool: + return True + + + @contextmanager + def values(self, v: StrAny) -> Iterator[None]: + p_values = self._values + self._values = v + yield + self._values = p_values diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index 03a12ba324..2ea3df7a96 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -1,40 +1,58 @@ from os import environ from os.path import isdir -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Tuple from dlt.common.typing import TSecretValue +from .provider import Provider + SECRET_STORAGE_PATH: str = "/run/secrets/%s" +class EnvironProvider(Provider): + + @staticmethod + def get_key_name(key: str, *namespaces: str) -> str: + # env key is always upper case + if namespaces: + namespaces = filter(lambda x: bool(x), namespaces) # type: ignore + env_key = "__".join((*namespaces, key)) + else: + env_key = key + return env_key.upper() + + @property + def name(self) -> str: + return "Environment Variables" + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + # apply namespace to the key + key = self.get_key_name(key, *namespaces) + if hint is TSecretValue: + # try secret storage + try: + # must conform to RFC1123 + secret_name = key.lower().replace("_", "-") + secret_path = SECRET_STORAGE_PATH % secret_name + # kubernetes stores secrets as files in a dir, docker compose plainly + if isdir(secret_path): + secret_path += "/" + secret_name + with open(secret_path, "r", encoding="utf-8") as f: + 
secret = f.read() + # add secret to environ so forks have access + # TODO: removing new lines is not always good. for password OK for PEMs not + # TODO: in regular secrets that is dealt with in particular configuration logic + environ[key] = secret.strip() + # do not strip returned secret + return secret, key + # includes FileNotFound + except OSError: + pass + return environ.get(key, None), key + + @property + def supports_secrets(self) -> bool: + return True -def get_key_name(key: str, namespace: str = None) -> str: - if namespace: - return namespace + "__" + key - else: - return key - - -def get_key(key: str, hint: Type[Any], namespace: str = None) -> Optional[str]: - # apply namespace to the key - key = get_key_name(key, namespace) - if hint is TSecretValue: - # try secret storage - try: - # must conform to RFC1123 - secret_name = key.lower().replace("_", "-") - secret_path = SECRET_STORAGE_PATH % secret_name - # kubernetes stores secrets as files in a dir, docker compose plainly - if isdir(secret_path): - secret_path += "/" + secret_name - with open(secret_path, "r", encoding="utf-8") as f: - secret = f.read() - # add secret to environ so forks have access - # TODO: removing new lines is not always good. for password OK for PEMs not - # TODO: in regular secrets that is dealt with in particular configuration logic - environ[key] = secret.strip() - # do not strip returned secret - return secret - # includes FileNotFound - except OSError: - pass - return environ.get(key, None) \ No newline at end of file + @property + def supports_namespaces(self) -> bool: + return True diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py new file mode 100644 index 0000000000..9635734f91 --- /dev/null +++ b/dlt/common/configuration/providers/provider.py @@ -0,0 +1,25 @@ +import abc +from typing import Any, Tuple, Type, Optional + + + +class Provider(abc.ABC): + + @abc.abstractmethod + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + pass + + @property + @abc.abstractmethod + def supports_secrets(self) -> bool: + pass + + @property + @abc.abstractmethod + def supports_namespaces(self) -> bool: + pass + + @property + @abc.abstractmethod + def name(self) -> str: + pass diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py new file mode 100644 index 0000000000..0f49691cbc --- /dev/null +++ b/dlt/common/configuration/providers/toml.py @@ -0,0 +1,79 @@ +import os +import tomlkit +from typing import Any, Optional, Tuple, Type + +from dlt.common.typing import StrAny + +from .provider import Provider + + +class TomlProvider(Provider): + + def __init__(self, file_name: str, project_dir: str = None) -> None: + self._file_name = file_name + self._toml_path = os.path.join(project_dir or os.path.abspath(os.path.join(".", ".dlt")), file_name) + self._toml = self._read_toml(self._toml_path) + + @staticmethod + def get_key_name(key: str, *namespaces: str) -> str: + # env key is always upper case + if namespaces: + namespaces = filter(lambda x: bool(x), namespaces) # type: ignore + env_key = ".".join((*namespaces, key)) + else: + env_key = key + return env_key + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + full_path = namespaces + (key,) + full_key = self.get_key_name(key, *namespaces) + node = self._toml + try: + for k in full_path: + node = node[k] + return node, full_key + except KeyError: + return None, full_key + + 
@property + def supports_namespaces(self) -> bool: + return True + + @staticmethod + def _read_toml(toml_path: str) -> StrAny: + if os.path.isfile(toml_path): + # TODO: raise an exception with an explanation to the end user what is this toml file that does not parse etc. + with open(toml_path, "r", encoding="utf-8") as f: + # use whitespace preserving parser + return tomlkit.load(f) + else: + return {} + + +class ConfigTomlProvider(TomlProvider): + + def __init__(self, project_dir: str = None) -> None: + super().__init__("config.toml", project_dir) + + @property + def name(self) -> str: + return "Pipeline config.toml" + + @property + def supports_secrets(self) -> bool: + return False + + + +class SecretsTomlProvider(TomlProvider): + + def __init__(self, project_dir: str = None) -> None: + super().__init__("secrets.toml", project_dir) + + @property + def name(self) -> str: + return "Pipeline secrets.toml" + + @property + def supports_secrets(self) -> bool: + return True diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py new file mode 100644 index 0000000000..0e9348fc0e --- /dev/null +++ b/dlt/common/configuration/resolve.py @@ -0,0 +1,316 @@ +import ast +import inspect +from collections.abc import Mapping as C_Mapping +from typing import Any, Dict, ContextManager, List, Optional, Sequence, Tuple, Type, TypeVar, get_origin + +from dlt.common import json, logger +from dlt.common.typing import TSecretValue, is_optional_type, extract_inner_type +from dlt.common.schema.utils import coerce_type, py_type_to_sc_type + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext +from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.exceptions import (LookupTrace, ConfigFieldMissingException, ConfigurationWrongTypeException, ConfigValueCannotBeCoercedException, ValueNotSecretException, InvalidInitialValue) + +CHECK_INTEGRITY_F: str = "check_integrity" +TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration) + + +def resolve_configuration(config: TConfiguration, *, namespaces: Tuple[str, ...] 
= (), initial_value: Any = None, accept_partial: bool = False) -> TConfiguration: + if not isinstance(config, BaseConfiguration): + raise ConfigurationWrongTypeException(type(config)) + + return _resolve_configuration(config, namespaces, (), initial_value, accept_partial) + + +def deserialize_value(key: str, value: Any, hint: Type[Any]) -> Any: + try: + if hint != Any: + hint_dt = py_type_to_sc_type(hint) + value_dt = py_type_to_sc_type(type(value)) + + # eval only if value is string and hint is "complex" + if value_dt == "text" and hint_dt == "complex": + if hint is tuple: + # use literal eval for tuples + value = ast.literal_eval(value) + else: + # use json for sequences and mappings + value = json.loads(value) + # exact types must match + if not isinstance(value, hint): + raise ValueError(value) + else: + # for types that are not complex, reuse schema coercion rules + if value_dt != hint_dt: + value = coerce_type(hint_dt, value_dt, value) + return value + except ConfigValueCannotBeCoercedException: + raise + except Exception as exc: + raise ConfigValueCannotBeCoercedException(key, value, hint) from exc + + +def serialize_value(value: Any) -> Any: + if value is None: + raise ValueError(value) + # return literal for tuples + if isinstance(value, tuple): + return str(value) + # coerce type to text which will use json for mapping and sequences + value_dt = py_type_to_sc_type(type(value)) + return coerce_type("text", value_dt, value) + + +def inject_namespace(namespace_context: ConfigNamespacesContext, merge_existing: bool = True) -> ContextManager[ConfigNamespacesContext]: + """Adds `namespace` context to container, making it injectable. Optionally merges the context already in the container with the one provided + + Args: + namespace_context (ConfigNamespacesContext): Instance providing a pipeline name and namespace context + merge_existing (bool, optional): Gets `pipeline_name` and `namespaces` from existing context if they are not provided in `namespace` argument. Defaults to True. 
+ + Yields: + Iterator[ConfigNamespacesContext]: Context manager with current namespace context + """ + container = Container() + existing_context = container[ConfigNamespacesContext] + + if merge_existing: + namespace_context.pipeline_name = namespace_context.pipeline_name or existing_context.pipeline_name + namespace_context.namespaces = namespace_context.namespaces or existing_context.namespaces + + return container.injectable_context(namespace_context) + + +def _resolve_configuration( + config: TConfiguration, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...], + initial_value: Any, + accept_partial: bool + ) -> TConfiguration: + # do not resolve twice + if config.is_resolved(): + return config + + config.__exception__ = None + try: + # if initial value is a Mapping then apply it + if isinstance(initial_value, C_Mapping): + config.update(initial_value) + # cannot be native initial value + initial_value = None + + # try to get the native representation of the configuration using the config namespace as a key + # allows, for example, to store connection string or service.json in their native form in single env variable or under single vault key + resolved_initial: Any = None + if config.__namespace__ or embedded_namespaces: + cf_n, emb_ns = _apply_embedded_namespaces_to_config_namespace(config.__namespace__, embedded_namespaces) + if cf_n: + resolved_initial, traces = _resolve_single_field(cf_n, type(config), None, explicit_namespaces, emb_ns) + _log_traces(config, cf_n, type(config), resolved_initial, traces) + # initial values cannot be dictionaries + if not isinstance(resolved_initial, C_Mapping): + initial_value = resolved_initial or initial_value + # if this is injectable context then return it immediately + if isinstance(resolved_initial, ContainerInjectableContext): + return resolved_initial # type: ignore + try: + try: + # use initial value to set config values + if initial_value: + config.from_native_representation(initial_value) + # if no initial value or initial value was passed via argument, resolve config normally (config always over explicit params) + if not initial_value or not resolved_initial: + raise NotImplementedError() + except ValueError: + raise InvalidInitialValue(type(config), type(initial_value)) + except NotImplementedError: + # if config does not support native form, resolve normally + _resolve_config_fields(config, explicit_namespaces, embedded_namespaces, accept_partial) + + _check_configuration_integrity(config) + # full configuration was resolved + config.__is_resolved__ = True + except ConfigFieldMissingException as cm_ex: + if not accept_partial: + raise + else: + # store the ConfigEntryMissingException to have full info on traces of missing fields + config.__exception__ = cm_ex + except Exception as ex: + # store the exception that happened in the resolution process + config.__exception__ = ex + raise + + return config + + +def _resolve_config_fields( + config: BaseConfiguration, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...], + accept_partial: bool + ) -> None: + + fields = config.get_resolvable_fields() + unresolved_fields: Dict[str, Sequence[LookupTrace]] = {} + + for key, hint in fields.items(): + # get default value + current_value = getattr(config, key, None) + # check if hint optional + is_optional = is_optional_type(hint) + # accept partial becomes True if type if optional so we do not fail on optional configs that do not resolve fully + accept_partial = accept_partial or is_optional + 
+ # if current value is BaseConfiguration, resolve that instance + if isinstance(current_value, BaseConfiguration): + # resolve only if not yet resolved otherwise just pass it + if not current_value.is_resolved(): + # add key as innermost namespace + current_value = _resolve_configuration(current_value, explicit_namespaces, embedded_namespaces + (key,), None, accept_partial) + else: + # extract hint from Optional / Literal / NewType hints + inner_hint = extract_inner_type(hint) + # extract origin from generic types + inner_hint = get_origin(inner_hint) or inner_hint + + # if inner_hint is BaseConfiguration then resolve it recursively + if inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration): + # create new instance and pass value from the provider as initial, add key to namespaces + current_value = _resolve_configuration(inner_hint(), explicit_namespaces, embedded_namespaces + (key,), current_value, accept_partial) + else: + + # resolve key value via active providers passing the original hint ie. to preserve TSecretValue + value, traces = _resolve_single_field(key, hint, config.__namespace__, explicit_namespaces, embedded_namespaces) + _log_traces(config, key, hint, value, traces) + # if value is resolved, then deserialize and coerce it + if value is not None: + current_value = deserialize_value(key, value, inner_hint) + + # collect unresolved fields + if not is_optional and current_value is None: + unresolved_fields[key] = traces + # set resolved value in config + setattr(config, key, current_value) + if unresolved_fields: + raise ConfigFieldMissingException(type(config).__name__, unresolved_fields) + + +def _log_traces(config: BaseConfiguration, key: str, hint: Type[Any], value: Any, traces: Sequence[LookupTrace]) -> None: + if logger.is_logging() and logger.log_level() == "DEBUG": + logger.debug(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + # print(f"Field {key} with type {hint} in {type(config).__name__} {'NOT RESOLVED' if value is None else 'RESOLVED'}") + for tr in traces: + # print(str(tr)) + logger.debug(str(tr)) + + +def _check_configuration_integrity(config: BaseConfiguration) -> None: + # python multi-inheritance is cooperative and this would require that all configurations cooperatively + # call each other check_integrity. this is not at all possible as we do not know which configs in the end will + # be mixed together. + + # get base classes in order of derivation + mro = type.mro(type(config)) + for c in mro: + # check if this class implements check_integrity (skip pure inheritance to not do double work) + if CHECK_INTEGRITY_F in c.__dict__ and callable(getattr(c, CHECK_INTEGRITY_F)): + # pass right class instance + c.__dict__[CHECK_INTEGRITY_F](config) + + +def _resolve_single_field( + key: str, + hint: Type[Any], + config_namespace: str, + explicit_namespaces: Tuple[str, ...], + embedded_namespaces: Tuple[str, ...] 
+ ) -> Tuple[Optional[Any], List[LookupTrace]]: + container = Container() + # get providers from container + providers = container[ConfigProvidersContext].providers + # get additional namespaces to look in from container + namespaces_context = container[ConfigNamespacesContext] + config_namespace, embedded_namespaces = _apply_embedded_namespaces_to_config_namespace(config_namespace, embedded_namespaces) + + # start looking from the top provider with most specific set of namespaces first + traces: List[LookupTrace] = [] + value = None + + def look_namespaces(pipeline_name: str = None) -> Any: + for provider in providers: + if provider.supports_namespaces: + # if explicit namespaces are provided, ignore the injected context + if explicit_namespaces: + ns = list(explicit_namespaces) + else: + ns = list(namespaces_context.namespaces) + # always extend with embedded namespaces + ns.extend(embedded_namespaces) + else: + # if provider does not support namespaces and pipeline name is set then ignore it + if pipeline_name: + continue + else: + # pass empty namespaces + ns = [] + + value = None + while True: + if (pipeline_name or config_namespace) and provider.supports_namespaces: + full_ns = ns.copy() + # pipeline, when provided, is the most outer and always present + if pipeline_name: + full_ns.insert(0, pipeline_name) + # config namespace, is always present and innermost + if config_namespace: + full_ns.append(config_namespace) + else: + full_ns = ns + value, ns_key = provider.get_value(key, hint, *full_ns) + # if secret is obtained from non secret provider, we must fail + cant_hold_it: bool = not provider.supports_secrets and _is_secret_hint(hint) + if value is not None and cant_hold_it: + raise ValueNotSecretException(provider.name, ns_key) + + # create trace, ignore container provider and providers that cant_hold_it + if provider.name != ContextProvider.NAME and not cant_hold_it: + traces.append(LookupTrace(provider.name, full_ns, ns_key, value)) + + if value is not None: + # value found, ignore other providers + return value + if len(ns) == 0: + # check next provider + break + # pop optional namespaces for less precise lookup + ns.pop() + + # first try with pipeline name as namespace, if present + if namespaces_context.pipeline_name: + value = look_namespaces(namespaces_context.pipeline_name) + # then without it + if value is None: + value = look_namespaces() + + return value, traces + + +def _apply_embedded_namespaces_to_config_namespace(config_namespace: str, embedded_namespaces: Tuple[str, ...]) -> Tuple[str, Tuple[str, ...]]: + # for the configurations that have __namespace__ (config_namespace) defined and are embedded in other configurations, + # the innermost embedded namespace replaces config_namespace + if embedded_namespaces: + # do not add key to embedded namespaces if it starts with _, those namespaces must be ignored + if not embedded_namespaces[-1].startswith("_"): + config_namespace = embedded_namespaces[-1] + embedded_namespaces = embedded_namespaces[:-1] + + return config_namespace, embedded_namespaces + + +def _is_secret_hint(hint: Type[Any]) -> bool: + return hint is TSecretValue or (inspect.isclass(hint) and issubclass(hint, CredentialsConfiguration)) diff --git a/dlt/common/configuration/run_configuration.py b/dlt/common/configuration/run_configuration.py deleted file mode 100644 index 383de508f3..0000000000 --- a/dlt/common/configuration/run_configuration.py +++ /dev/null @@ -1,73 +0,0 @@ -import randomname -from os.path import isfile -from typing import Any, Optional, 
Tuple, IO - -from dlt.common.typing import StrAny, DictStrAny -from dlt.common.utils import encoding_for_mode -from dlt.common.configuration.exceptions import ConfigFileNotFoundException - -DEVELOPMENT_CONFIG_FILES_STORAGE_PATH = "_storage/config/%s" -PRODUCTION_CONFIG_FILES_STORAGE_PATH = "/run/config/%s" - - -class BaseConfiguration: - - # will be set to true if not all config entries could be resolved - __is_partial__: bool = True - __namespace__: str = None - - @classmethod - def as_dict(config, lowercase: bool = True) -> DictStrAny: - may_lower = lambda k: k.lower() if lowercase else k - return {may_lower(k):getattr(config, k) for k in dir(config) if not callable(getattr(config, k)) and not k.startswith("__")} - - @classmethod - def apply_dict(config, values: StrAny, uppercase: bool = True, apply_non_spec: bool = False) -> None: - if not values: - return - - for k, v in values.items(): - k = k.upper() if uppercase else k - # overwrite only declared values - if not apply_non_spec and hasattr(config, k): - setattr(config, k, v) - - -class CredentialsConfiguration(BaseConfiguration): - pass - - -class RunConfiguration(BaseConfiguration): - PIPELINE_NAME: Optional[str] = None # the name of the component - SENTRY_DSN: Optional[str] = None # keep None to disable Sentry - PROMETHEUS_PORT: Optional[int] = None # keep None to disable Prometheus - LOG_FORMAT: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' - LOG_LEVEL: str = "DEBUG" - IS_DEVELOPMENT_CONFIG: bool = True - REQUEST_TIMEOUT: Tuple[int, int] = (15, 300) # default request timeout for all http clients - CONFIG_FILES_STORAGE_PATH: str = DEVELOPMENT_CONFIG_FILES_STORAGE_PATH - - @classmethod - def check_integrity(cls) -> None: - # generate random name if missing - if not cls.PIPELINE_NAME: - cls.PIPELINE_NAME = "dlt_" + randomname.get_name().replace("-", "_") - # if CONFIG_FILES_STORAGE_PATH not overwritten and we are in production mode - if cls.CONFIG_FILES_STORAGE_PATH == DEVELOPMENT_CONFIG_FILES_STORAGE_PATH and not cls.IS_DEVELOPMENT_CONFIG: - # set to mount where config files will be present - cls.CONFIG_FILES_STORAGE_PATH = PRODUCTION_CONFIG_FILES_STORAGE_PATH - - @classmethod - def has_configuration_file(cls, name: str) -> bool: - return isfile(cls.get_configuration_file_path(name)) - - @classmethod - def open_configuration_file(cls, name: str, mode: str) -> IO[Any]: - path = cls.get_configuration_file_path(name) - if not cls.has_configuration_file(name): - raise ConfigFileNotFoundException(path) - return open(path, mode, encoding=encoding_for_mode(mode)) - - @classmethod - def get_configuration_file_path(cls, name: str) -> str: - return cls.CONFIG_FILES_STORAGE_PATH % name \ No newline at end of file diff --git a/dlt/common/configuration/schema_volume_configuration.py b/dlt/common/configuration/schema_volume_configuration.py deleted file mode 100644 index b3018a1782..0000000000 --- a/dlt/common/configuration/schema_volume_configuration.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional, Literal - -from dlt.common.configuration import BaseConfiguration - -TSchemaFileFormat = Literal["json", "yaml"] - - -class SchemaVolumeConfiguration(BaseConfiguration): - SCHEMA_VOLUME_PATH: str = "_storage/schemas" # path to volume with default schemas - IMPORT_SCHEMA_PATH: Optional[str] = None # import schema from external location - EXPORT_SCHEMA_PATH: Optional[str] = None # export schema to external location - EXTERNAL_SCHEMA_FORMAT: TSchemaFileFormat = "yaml" # format in which to 
expect external schema - EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS: bool = True # remove default values when exporting schema - - -class ProductionSchemaVolumeConfiguration(SchemaVolumeConfiguration): - SCHEMA_VOLUME_PATH: str = None diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py new file mode 100644 index 0000000000..bd48a01909 --- /dev/null +++ b/dlt/common/configuration/specs/__init__.py @@ -0,0 +1,9 @@ +from .run_configuration import RunConfiguration # noqa: F401 +from .base_configuration import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext # noqa: F401 +from .normalize_volume_configuration import NormalizeVolumeConfiguration # noqa: F401 +from .load_volume_configuration import LoadVolumeConfiguration # noqa: F401 +from .schema_volume_configuration import SchemaVolumeConfiguration, TSchemaFileFormat # noqa: F401 +from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401 +from .gcp_client_credentials import GcpClientCredentials # noqa: F401 +from .postgres_credentials import PostgresCredentials # noqa: F401 +from .config_namespace_context import ConfigNamespacesContext # noqa: F401 \ No newline at end of file diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py new file mode 100644 index 0000000000..9e98958778 --- /dev/null +++ b/dlt/common/configuration/specs/base_configuration.py @@ -0,0 +1,182 @@ +import inspect +import contextlib +import dataclasses +from typing import Callable, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_origin, overload, ClassVar + +if TYPE_CHECKING: + TDtcField = dataclasses.Field[Any] +else: + TDtcField = dataclasses.Field + +from dlt.common.typing import TAnyClass, extract_inner_type, is_optional_type +from dlt.common.schema.utils import py_type_to_sc_type +from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported + + +def is_valid_hint(hint: Type[Any]) -> bool: + hint = extract_inner_type(hint) + hint = get_origin(hint) or hint + if hint is Any: + return True + if hint is ClassVar: + # class vars are skipped by dataclass + return True + if inspect.isclass(hint) and issubclass(hint, BaseConfiguration): + return True + with contextlib.suppress(TypeError): + py_type_to_sc_type(hint) + return True + return False + + +@overload +def configspec(cls: Type[TAnyClass], /, *, init: bool = False) -> Type[TAnyClass]: + ... + + +@overload +def configspec(cls: None = ..., /, *, init: bool = False) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: + ... 
+ + +def configspec(cls: Optional[Type[Any]] = None, /, *, init: bool = False) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: + + def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: + # if type does not derive from BaseConfiguration then derive it + with contextlib.suppress(NameError): + if not issubclass(cls, BaseConfiguration): + # keep the original module + fields = {"__module__": cls.__module__, "__annotations__": getattr(cls, "__annotations__", {})} + cls = type(cls.__name__, (cls, BaseConfiguration), fields) + # get all annotations without corresponding attributes and set them to None + for ann in cls.__annotations__: + if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_impl")): + setattr(cls, ann, None) + # get all attributes without corresponding annotations + for att_name, att_value in cls.__dict__.items(): + # skip callables, dunder names, class variables and some special names + if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")): + if att_name not in cls.__annotations__: + raise ConfigFieldMissingTypeHintException(att_name, cls) + hint = cls.__annotations__[att_name] + if not is_valid_hint(hint): + raise ConfigFieldTypeHintNotSupported(att_name, cls, hint) + # do not generate repr as it may contain secret values + return dataclasses.dataclass(cls, init=init, eq=False, repr=False) # type: ignore + + # called with parenthesis + if cls is None: + return wrap + + return wrap(cls) + + +@configspec +class BaseConfiguration(MutableMapping[str, Any]): + + # true when all config fields were resolved and have a specified value type + __is_resolved__: bool = dataclasses.field(default = False, init=False, repr=False) + # namespace used by config providers when searching for keys + __namespace__: str = dataclasses.field(default = None, init=False, repr=False) + # holds the exception that prevented the full resolution + __exception__: Exception = dataclasses.field(default = None, init=False, repr=False) + + def from_native_representation(self, native_value: Any) -> None: + """Initialize the configuration fields by parsing the `initial_value` which should be a native representation of the configuration + or credentials, for example database connection string or JSON serialized GCP service credentials file. + + Args: + initial_value (Any): A native representation of the configuration + + Raises: + NotImplementedError: This configuration does not have a native representation + ValueError: The value provided cannot be parsed as native representation + """ + raise NotImplementedError() + + def to_native_representation(self) -> Any: + """Represents the configuration instance in its native form ie. database connection string or JSON serialized GCP service credentials file. + + Raises: + NotImplementedError: This configuration does not have a native representation + + Returns: + Any: A native representation of the configuration + """ + raise NotImplementedError() + + def get_resolvable_fields(self) -> Dict[str, type]: + """Returns a mapping of fields to their type hints. 
Dunder should not be resolved and are not returned""" + return {f.name:f.type for f in self.__fields_dict().values() if not f.name.startswith("__")} + + def is_resolved(self) -> bool: + return self.__is_resolved__ + + def is_partial(self) -> bool: + """Returns True when any required resolvable field has its value missing.""" + if self.__is_resolved__: + return False + # check if all resolvable fields have value + return any( + field for field, hint in self.get_resolvable_fields().items() if getattr(self, field) is None and not is_optional_type(hint) + ) + + # implement dictionary-compatible interface on top of dataclass + + def __getitem__(self, __key: str) -> Any: + if self.__has_attr(__key): + return getattr(self, __key) + else: + raise KeyError(__key) + + def __setitem__(self, __key: str, __value: Any) -> None: + if self.__has_attr(__key): + setattr(self, __key, __value) + else: + try: + if not self.__ignore_set_unknown_keys: + # assert getattr(self, "__ignore_set_unknown_keys") is not None + raise KeyError(__key) + except AttributeError: + # __ignore_set_unknown_keys attribute may not be present at the moment of checking, __init__ of BaseConfiguration is not typically called + raise KeyError(__key) + + def __delitem__(self, __key: str) -> None: + raise KeyError("Configuration fields cannot be deleted") + + def __iter__(self) -> Iterator[str]: + return filter(lambda k: not k.startswith("__"), self.__fields_dict().__iter__()) + + def __len__(self) -> int: + return sum(1 for _ in self.__iter__()) + + def update(self, other: Any = (), /, **kwds: Any) -> None: + try: + self.__ignore_set_unknown_keys = True + super().update(other, **kwds) + finally: + self.__ignore_set_unknown_keys = False + + # helper functions + + def __has_attr(self, __key: str) -> bool: + return __key in self.__fields_dict() and not __key.startswith("__") + + def __fields_dict(self) -> Dict[str, TDtcField]: + return self.__dataclass_fields__ # type: ignore + + +@configspec +class CredentialsConfiguration(BaseConfiguration): + """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" + + __namespace__: str = "credentials" + + + +@configspec +class ContainerInjectableContext(BaseConfiguration): + """Base class for all configurations that may be injected from Container. Injectable configurations are called contexts""" + + # If True, `Container` is allowed to create default context instance, if none exists + can_create_default: ClassVar[bool] = True diff --git a/dlt/common/configuration/specs/config_namespace_context.py b/dlt/common/configuration/specs/config_namespace_context.py new file mode 100644 index 0000000000..5c4bbd2725 --- /dev/null +++ b/dlt/common/configuration/specs/config_namespace_context.py @@ -0,0 +1,14 @@ +from typing import List, Optional, Tuple, TYPE_CHECKING + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec + + +@configspec(init=True) +class ConfigNamespacesContext(ContainerInjectableContext): + pipeline_name: Optional[str] + namespaces: Tuple[str, ...] = () + + if TYPE_CHECKING: + # provide __init__ signature when type checking + def __init__(self, pipeline_name:str = None, namespaces: Tuple[str, ...] = ()) -> None: + ... 
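For orientation, a minimal sketch (not part of the diff) of how a spec behaves once decorated with `@configspec`: annotated attributes become dataclass fields and instances expose the MutableMapping interface defined on `BaseConfiguration`. The `ConnectionConfiguration` class and its fields are hypothetical, invented only for this illustration.

from typing import Optional
from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec

@configspec
class ConnectionConfiguration(BaseConfiguration):  # hypothetical spec, for illustration only
    host: str = None                   # required field, must get a value before the config is complete
    port: int = 5432
    database: Optional[str] = None     # optional field, may remain None

conn = ConnectionConfiguration()
conn["host"] = "localhost"             # dict-style access maps onto the dataclass fields
assert conn["port"] == 5432
assert not conn.is_partial()           # every non-optional field now has a value
assert list(conn) == ["host", "port", "database"]  # dunder bookkeeping fields are filtered out when iterating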
diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py new file mode 100644 index 0000000000..02cd4d6d37 --- /dev/null +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -0,0 +1,46 @@ + + +from typing import List + +from dlt.common.configuration.exceptions import DuplicateConfigProviderException +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.providers.toml import SecretsTomlProvider, ConfigTomlProvider +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec + + +@configspec +class ConfigProvidersContext(ContainerInjectableContext): + """Injectable list of providers used by the configuration `resolve` module""" + providers: List[Provider] + + def __init__(self) -> None: + super().__init__() + # add default providers, ContextProvider must be always first - it will provide contexts + self.providers = [ContextProvider(), EnvironProvider(), SecretsTomlProvider(), ConfigTomlProvider()] + + def __getitem__(self, name: str) -> Provider: + try: + return next(p for p in self.providers if p.name == name) + except StopIteration: + raise KeyError(name) + + def __contains__(self, name: object) -> bool: + try: + self.__getitem__(name) # type: ignore + return True + except KeyError: + return False + + def add_provider(self, provider: Provider) -> None: + if provider.name in self: + raise DuplicateConfigProviderException(provider.name) + self.providers.append(provider) + + +# TODO: implement ConfigProvidersConfiguration and +# @configspec +# class ConfigProvidersConfiguration(BaseConfiguration): +# with_aws_secrets: bool = False +# with_google_secrets: bool = False diff --git a/dlt/common/configuration/specs/gcp_client_credentials.py b/dlt/common/configuration/specs/gcp_client_credentials.py new file mode 100644 index 0000000000..857f0ab97c --- /dev/null +++ b/dlt/common/configuration/specs/gcp_client_credentials.py @@ -0,0 +1,43 @@ +from typing import Any +from dlt.common import json + +from dlt.common.typing import StrAny, TSecretValue +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec + + +@configspec +class GcpClientCredentials(CredentialsConfiguration): + + project_id: str = None + type: str = "service_account" # noqa: A003 + private_key: TSecretValue = None + location: str = "US" + token_uri: str = "https://oauth2.googleapis.com/token" + client_email: str = None + + http_timeout: float = 15.0 + file_upload_timeout: float = 30 * 60.0 + retry_deadline: float = 600 # how long to retry the operation in case of error, the backoff 60s + + def from_native_representation(self, native_value: Any) -> None: + if not isinstance(native_value, str): + raise ValueError(native_value) + try: + service_dict = json.loads(native_value) + self.update(service_dict) + except Exception: + raise ValueError(native_value) + + def check_integrity(self) -> None: + if self.private_key and self.private_key[-1] != "\n": + # must end with new line, otherwise won't be parsed by Crypto + self.private_key = TSecretValue(self.private_key + "\n") + + def to_native_representation(self) -> StrAny: + return { + "type": self.type, + "project_id": self.project_id, + "private_key": self.private_key, + "token_uri": self.token_uri, + "client_email": self.client_email + } \ No 
newline at end of file diff --git a/dlt/common/configuration/specs/load_volume_configuration.py b/dlt/common/configuration/specs/load_volume_configuration.py new file mode 100644 index 0000000000..c014a66d43 --- /dev/null +++ b/dlt/common/configuration/specs/load_volume_configuration.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + + +@configspec(init=True) +class LoadVolumeConfiguration(BaseConfiguration): + load_volume_path: str = None # path to volume where files to be loaded to analytical storage are stored + delete_completed_jobs: bool = False # if set to true the folder with completed jobs will be deleted + + if TYPE_CHECKING: + def __init__(self, load_volume_path: str = None, delete_completed_jobs: bool = None) -> None: + ... diff --git a/dlt/common/configuration/specs/normalize_volume_configuration.py b/dlt/common/configuration/specs/normalize_volume_configuration.py new file mode 100644 index 0000000000..49aa40df40 --- /dev/null +++ b/dlt/common/configuration/specs/normalize_volume_configuration.py @@ -0,0 +1,12 @@ +from typing import TYPE_CHECKING + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + + +@configspec(init=True) +class NormalizeVolumeConfiguration(BaseConfiguration): + normalize_volume_path: str = None # path to volume where normalized loader files will be stored + + if TYPE_CHECKING: + def __init__(self, normalize_volume_path: str = None) -> None: + ... diff --git a/dlt/common/configuration/specs/pool_runner_configuration.py b/dlt/common/configuration/specs/pool_runner_configuration.py new file mode 100644 index 0000000000..06a95ceff1 --- /dev/null +++ b/dlt/common/configuration/specs/pool_runner_configuration.py @@ -0,0 +1,18 @@ +from typing import Literal, Optional + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + +TPoolType = Literal["process", "thread", "none"] + + +@configspec +class PoolRunnerConfiguration(BaseConfiguration): + pool_type: TPoolType = None # type of pool to run, must be set in derived configs + workers: Optional[int] = None # how many threads/processes in the pool + run_sleep: float = 0.5 # how long to sleep between runs with workload, seconds + run_sleep_idle: float = 1.0 # how long to sleep when no more items are pending, seconds + run_sleep_when_failed: float = 1.0 # how long to sleep between the runs when failed + is_single_run: bool = False # should run only once until all pending data is processed, and exit + wait_runs: int = 0 # how many runs to wait for first data coming in is IS_SINGLE_RUN is set + exit_on_exception: bool = False # should exit on exception + stop_after_runs: int = 10000 # will stop runner with exit code -2 after so many runs, that prevents memory fragmentation diff --git a/dlt/common/configuration/specs/postgres_credentials.py b/dlt/common/configuration/specs/postgres_credentials.py new file mode 100644 index 0000000000..42d2361183 --- /dev/null +++ b/dlt/common/configuration/specs/postgres_credentials.py @@ -0,0 +1,28 @@ +from typing import Any + +from dlt.common.typing import StrAny, TSecretValue +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec + + +@configspec +class PostgresCredentials(CredentialsConfiguration): + + dbname: str = None + password: TSecretValue = None + user: str = None + host: str = None + port: int = 5439 + connect_timeout: int = 15 + + def 
from_native_representation(self, initial_value: Any) -> None: + if not isinstance(initial_value, str): + raise ValueError(initial_value) + # TODO: parse postgres connection string + raise NotImplementedError() + + def check_integrity(self) -> None: + self.dbname = self.dbname.lower() + self.password = TSecretValue(self.password.strip()) + + def to_native_representation(self) -> StrAny: + raise NotImplementedError() diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py new file mode 100644 index 0000000000..e19faf1116 --- /dev/null +++ b/dlt/common/configuration/specs/run_configuration.py @@ -0,0 +1,34 @@ +from os.path import isfile +from typing import Any, Optional, Tuple, IO + +from dlt.common.utils import encoding_for_mode, entry_point_file_stem +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.configuration.exceptions import ConfigFileNotFoundException + + +@configspec +class RunConfiguration(BaseConfiguration): + pipeline_name: Optional[str] = None + sentry_dsn: Optional[str] = None # keep None to disable Sentry + prometheus_port: Optional[int] = None # keep None to disable Prometheus + log_format: str = '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}' + log_level: str = "DEBUG" + request_timeout: Tuple[int, int] = (15, 300) # default request timeout for all http clients + config_files_storage_path: str = "/run/config/%s" + + def check_integrity(self) -> None: + # generate pipeline name from the entry point script name + if not self.pipeline_name: + self.pipeline_name = "dlt_" + (entry_point_file_stem() or "pipeline") + + def has_configuration_file(self, name: str) -> bool: + return isfile(self.get_configuration_file_path(name)) + + def open_configuration_file(self, name: str, mode: str) -> IO[Any]: + path = self.get_configuration_file_path(name) + if not self.has_configuration_file(name): + raise ConfigFileNotFoundException(path) + return open(path, mode, encoding=encoding_for_mode(mode)) + + def get_configuration_file_path(self, name: str) -> str: + return self.config_files_storage_path % name diff --git a/dlt/common/configuration/specs/schema_volume_configuration.py b/dlt/common/configuration/specs/schema_volume_configuration.py new file mode 100644 index 0000000000..a5b70d3068 --- /dev/null +++ b/dlt/common/configuration/specs/schema_volume_configuration.py @@ -0,0 +1,18 @@ +from typing import Optional, Literal, TYPE_CHECKING + +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec + +TSchemaFileFormat = Literal["json", "yaml"] + + +@configspec(init=True) +class SchemaVolumeConfiguration(BaseConfiguration): + schema_volume_path: str = None # path to volume with default schemas + import_schema_path: Optional[str] = None # import schema from external location + export_schema_path: Optional[str] = None # export schema to external location + external_schema_format: TSchemaFileFormat = "yaml" # format in which to expect external schema + external_schema_format_remove_defaults: bool = True # remove default values when exporting schema + + if TYPE_CHECKING: + def __init__(self, schema_volume_path: str = None, import_schema_path: str = None, export_schema_path: str = None) -> None: + ...
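As a rough usage sketch of the credentials specs above: `GcpClientCredentials` takes the JSON-serialized service account file as its native representation, and `check_integrity` pads the private key with the trailing newline the crypto parser expects. All values below are fabricated placeholders, not working credentials.

from dlt.common.configuration.specs import GcpClientCredentials

gcp = GcpClientCredentials()
# the native form is the service account JSON; only a few of its fields are shown here
gcp.from_native_representation('{"project_id": "demo-project", "client_email": "loader@demo-project.iam.gserviceaccount.com", "private_key": "-----BEGIN PRIVATE KEY-----abc"}')
gcp.check_integrity()                                   # appends the missing trailing "\n" to the private key
assert gcp.to_native_representation()["project_id"] == "demo-project"
assert gcp.private_key.endswith("\n")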
diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py deleted file mode 100644 index cced372730..0000000000 --- a/dlt/common/configuration/utils.py +++ /dev/null @@ -1,168 +0,0 @@ -import sys -import semver -from typing import Any, Dict, List, Mapping, Type, TypeVar, cast - -from dlt.common.typing import StrAny, is_optional_type, is_literal_type -from dlt.common.configuration import BaseConfiguration -from dlt.common.configuration.providers import environ -from dlt.common.configuration.exceptions import (ConfigEntryMissingException, - ConfigEnvValueCannotBeCoercedException) -from dlt.common.utils import uniq_id - -SIMPLE_TYPES: List[Any] = [int, bool, list, dict, tuple, bytes, set, float] -# those types and Optionals of those types should not be passed to eval function -NON_EVAL_TYPES = [str, None, Any] -# allows to coerce (type1 from type2) -ALLOWED_TYPE_COERCIONS = [(float, int), (str, int), (str, float)] -IS_DEVELOPMENT_CONFIG_KEY: str = "IS_DEVELOPMENT_CONFIG" -CHECK_INTEGRITY_F: str = "check_integrity" - -TConfiguration = TypeVar("TConfiguration", bound=Type[BaseConfiguration]) -# TODO: remove production configuration support -TProductionConfiguration = TypeVar("TProductionConfiguration", bound=Type[BaseConfiguration]) - - -def make_configuration(config: TConfiguration, - production_config: TProductionConfiguration, - initial_values: StrAny = None, - accept_partial: bool = False, - skip_subclass_check: bool = False) -> TConfiguration: - if not skip_subclass_check: - assert issubclass(production_config, config) - - final_config: TConfiguration = config if _is_development_config() else production_config - possible_keys_in_config = _get_config_attrs_with_hints(final_config) - # create dynamic class type to not touch original config variables - derived_config: TConfiguration = cast(TConfiguration, - type(final_config.__name__ + "_" + uniq_id(), (final_config, ), {}) - ) - # apply initial values while preserving hints - derived_config.apply_dict(initial_values) - - _apply_environ_to_config(derived_config, possible_keys_in_config) - try: - _is_config_bounded(derived_config, possible_keys_in_config) - _check_configuration_integrity(derived_config) - # full configuration was resolved - derived_config.__is_partial__ = False - except ConfigEntryMissingException: - if not accept_partial: - raise - _add_module_version(derived_config) - - return derived_config - - -def is_direct_descendant(child: Type[Any], base: Type[Any]) -> bool: - # TODO: there may be faster way to get direct descendant that mro - # note: at index zero there's child - return base == type.mro(child)[1] - - -def _is_development_config() -> bool: - # get from environment - is_dev_config: bool = None - try: - is_dev_config = _coerce_single_value(IS_DEVELOPMENT_CONFIG_KEY, environ.get_key(IS_DEVELOPMENT_CONFIG_KEY, bool), bool) - except ConfigEnvValueCannotBeCoercedException as coer_exc: - # pass for None: this key may not be present - if coer_exc.env_value is None: - pass - else: - # anything else that cannot corece must raise - raise - return True if is_dev_config is None else is_dev_config - - -def _add_module_version(config: TConfiguration) -> None: - try: - v = sys._getframe(1).f_back.f_globals["__version__"] - semver.VersionInfo.parse(v) - setattr(config, "_VERSION", v) # noqa: B010 - except KeyError: - pass - - -def _apply_environ_to_config(config: TConfiguration, keys_in_config: Mapping[str, type]) -> None: - for key, hint in keys_in_config.items(): - value = environ.get_key(key, hint, 
config.__namespace__) - if value is not None: - value_from_environment_variable = _coerce_single_value(key, value, hint) - # set value - setattr(config, key, value_from_environment_variable) - - -def _is_config_bounded(config: TConfiguration, keys_in_config: Mapping[str, type]) -> None: - # TODO: here we assume all keys are taken from environ provider, that should change when we introduce more providers - _unbound_attrs = [ - environ.get_key_name(key, config.__namespace__) for key in keys_in_config if getattr(config, key) is None and not is_optional_type(keys_in_config[key]) - ] - - if len(_unbound_attrs) > 0: - raise ConfigEntryMissingException(_unbound_attrs, config.__namespace__) - - -def _check_configuration_integrity(config: TConfiguration) -> None: - # python multi-inheritance is cooperative and this would require that all configurations cooperatively - # call each other check_integrity. this is not at all possible as we do not know which configs in the end will - # be mixed together. - - # get base classes in order of derivation - mro = type.mro(config) - for c in mro: - # check if this class implements check_integrity (skip pure inheritance to not do double work) - if CHECK_INTEGRITY_F in c.__dict__ and callable(getattr(c, CHECK_INTEGRITY_F)): - # access unbounded __func__ to pass right class type so we check settings of the tip of mro - c.__dict__[CHECK_INTEGRITY_F].__func__(config) - - -def _coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: - try: - hint_primitive_type = _extract_simple_type(hint) - if hint_primitive_type not in NON_EVAL_TYPES: - # create primitive types out of strings - typed_value = eval(value) # nosec - # for primitive types check coercion - if hint_primitive_type in SIMPLE_TYPES and type(typed_value) != hint_primitive_type: - # allow some exceptions - coerce_exception = next( - (e for e in ALLOWED_TYPE_COERCIONS if e == (hint_primitive_type, type(typed_value))), None) - if coerce_exception: - return hint_primitive_type(typed_value) - else: - raise ConfigEnvValueCannotBeCoercedException(key, typed_value, hint) - return typed_value - else: - return value - except ConfigEnvValueCannotBeCoercedException: - raise - except Exception as exc: - raise ConfigEnvValueCannotBeCoercedException(key, value, hint) from exc - - -def _get_config_attrs_with_hints(config: TConfiguration) -> Dict[str, type]: - keys: Dict[str, type] = {} - mro = type.mro(config) - for cls in reversed(mro): - # update in reverse derivation order so derived classes overwrite hints from base classes - if cls is not object: - keys.update( - [(attr, cls.__annotations__.get(attr, None)) - # if hasattr(config, '__annotations__') and attr in config.__annotations__ else None) - for attr in cls.__dict__.keys() if not callable(getattr(cls, attr)) and not attr.startswith("__") - ]) - return keys - - -def _extract_simple_type(hint: Type[Any]) -> Type[Any]: - # extract optional type and call recursively - if is_literal_type(hint): - # assume that all literals are of the same type - return _extract_simple_type(type(hint.__args__[0])) - if is_optional_type(hint): - # todo: use `get_args` in python 3.8 - return _extract_simple_type(hint.__args__[0]) - if not hasattr(hint, "__supertype__"): - return hint - # descend into supertypes of NewType - return _extract_simple_type(hint.__supertype__) diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py new file mode 100644 index 0000000000..89d4607c90 --- /dev/null +++ b/dlt/common/data_writers/__init__.py @@ -0,0 +1,3 @@ 
+from dlt.common.data_writers.writers import DataWriter, TLoaderFileFormat +from dlt.common.data_writers.buffered import BufferedDataWriter +from dlt.common.data_writers.escape import escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier \ No newline at end of file diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py new file mode 100644 index 0000000000..1423fa3abe --- /dev/null +++ b/dlt/common/data_writers/buffered.py @@ -0,0 +1,105 @@ +from typing import List, IO, Any, Optional + +from dlt.common.utils import uniq_id +from dlt.common.typing import TDataItem, TDataItems +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.data_writers.exceptions import BufferedDataWriterClosed, InvalidFileNameTemplateException +from dlt.common.data_writers.writers import DataWriter +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.configuration import with_config + + +class BufferedDataWriter: + + @with_config(only_kw=True, namespaces=("data_writer",)) + def __init__(self, file_format: TLoaderFileFormat, file_name_template: str, *, buffer_max_items: int = 5000, file_max_items: int = None, file_max_bytes: int = None): + self.file_format = file_format + self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) + # validate if template has correct placeholders + self.file_name_template = file_name_template + self.all_files: List[str] = [] + # buffered items must be less than max items in file + self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items) + self.file_max_bytes = file_max_bytes + self.file_max_items = file_max_items + + self._current_columns: TTableSchemaColumns = None + self._file_name: str = None + self._buffered_items: List[TDataItem] = [] + self._writer: DataWriter = None + self._file: IO[Any] = None + self._closed = False + try: + self._rotate_file() + except TypeError: + raise InvalidFileNameTemplateException(file_name_template) + + def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> None: + self._ensure_open() + # rotate file if columns changed and writer does not allow for that + # as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths + if self._writer and not self._writer.data_format().supports_schema_changes and len(columns) != len(self._current_columns): + self._rotate_file() + # until the first chunk is written we can change the columns schema freely + self._current_columns = columns + if isinstance(item, List): + # items coming in single list will be written together, not matter how many are there + self._buffered_items.extend(item) + else: + self._buffered_items.append(item) + # flush if max buffer exceeded + if len(self._buffered_items) >= self.buffer_max_items: + self._flush_items() + # rotate the file if max_bytes exceeded + if self._file: + # rotate on max file size + if self.file_max_bytes and self._file.tell() >= self.file_max_bytes: + self._rotate_file() + # rotate on max items + if self.file_max_items and self._writer.items_count >= self.file_max_items: + self._rotate_file() + + def close_writer(self) -> None: + self._ensure_open() + self._flush_and_close_file() + self._closed = True + + @property + def closed(self) -> bool: + return self._closed + + def _rotate_file(self) -> None: + self._flush_and_close_file() + self._file_name = self.file_name_template % uniq_id(5) + "." 
+ self._file_format_spec.file_extension + + def _flush_items(self) -> None: + if len(self._buffered_items) > 0: + # we only open a writer when there are any files in the buffer and first flush is requested + if not self._writer: + # create new writer and write header + if self._file_format_spec.is_binary_format: + self._file = open(self._file_name, "wb") + else: + self._file = open(self._file_name, "wt", encoding="utf-8") + self._writer = DataWriter.from_file_format(self.file_format, self._file) + self._writer.write_header(self._current_columns) + # write buffer + self._writer.write_data(self._buffered_items) + self._buffered_items.clear() + + def _flush_and_close_file(self) -> None: + # if any buffered items exist, flush them + self._flush_items() + # if writer exists then close it + if self._writer: + # write the footer of a file + self._writer.write_footer() + self._file.close() + # add file written to the list so we can commit all the files later + self.all_files.append(self._file_name) + self._writer = None + self._file = None + + def _ensure_open(self) -> None: + if self._closed: + raise BufferedDataWriterClosed(self._file_name) \ No newline at end of file diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py new file mode 100644 index 0000000000..a8cef5e31d --- /dev/null +++ b/dlt/common/data_writers/escape.py @@ -0,0 +1,21 @@ +import re + +# use regex to escape characters in single pass +SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} +SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL) + + +def escape_redshift_literal(v: str) -> str: + # https://www.postgresql.org/docs/9.3/sql-syntax-lexical.html + # looks like this is the only thing we need to escape for Postgres > 9.1 + # redshift keeps \ as escape character which is pre 9 behavior + return "{}{}{}".format("'", SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'") + + +def escape_redshift_identifier(v: str) -> str: + return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"' + + +def escape_bigquery_identifier(v: str) -> str: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical + return "`" + v.replace("\\", "\\\\").replace("`","\\`") + "`" diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py new file mode 100644 index 0000000000..ffba6a49ba --- /dev/null +++ b/dlt/common/data_writers/exceptions.py @@ -0,0 +1,17 @@ +from dlt.common.exceptions import DltException + + +class DataWriterException(DltException): + pass + + +class InvalidFileNameTemplateException(DataWriterException, ValueError): + def __init__(self, file_name_template: str): + self.file_name_template = file_name_template + super().__init__(f"Wrong file name template {file_name_template}. 
File name template must contain exactly one %s formatter") + + +class BufferedDataWriterClosed(DataWriterException): + def __init__(self, file_name: str): + self.file_name = file_name + super().__init__(f"Writer with recent file name {file_name} is already closed") diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py new file mode 100644 index 0000000000..399e76c506 --- /dev/null +++ b/dlt/common/data_writers/writers.py @@ -0,0 +1,159 @@ +import abc +import jsonlines +from dataclasses import dataclass +from typing import Any, Dict, Sequence, IO, Literal, Type +from datetime import date, datetime # noqa: I251 + +from dlt.common import json +from dlt.common.typing import StrAny +from dlt.common.json import json_typed_dumps +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal +from dlt.common.destination import TLoaderFileFormat + + +@dataclass +class TFileFormatSpec: + file_format: TLoaderFileFormat + file_extension: str + is_binary_format: bool + supports_schema_changes: bool + + +class DataWriter(abc.ABC): + def __init__(self, f: IO[Any]) -> None: + self._f = f + self.items_count = 0 + + @abc.abstractmethod + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + pass + + def write_data(self, rows: Sequence[Any]) -> None: + self.items_count += len(rows) + + @abc.abstractmethod + def write_footer(self) -> None: + pass + + def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: + self.write_header(columns_schema) + self.write_data(rows) + self.write_footer() + + + @classmethod + @abc.abstractmethod + def data_format(cls) -> TFileFormatSpec: + pass + + @classmethod + def from_file_format(cls, file_format: TLoaderFileFormat, f: IO[Any]) -> "DataWriter": + return cls.class_factory(file_format)(f) + + @classmethod + def data_format_from_file_format(cls, file_format: TLoaderFileFormat) -> TFileFormatSpec: + return cls.class_factory(file_format).data_format() + + @staticmethod + def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: + if file_format == "jsonl": + return JsonlWriter + elif file_format == "puae-jsonl": + return JsonlListPUAEncodeWriter + elif file_format == "insert_values": + return InsertValuesWriter + else: + raise ValueError(file_format) + + +class JsonlWriter(DataWriter): + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + pass + + def write_data(self, rows: Sequence[Any]) -> None: + super().write_data(rows) + # use jsonl to write load files https://jsonlines.org/ + with jsonlines.Writer(self._f, dumps=json.dumps) as w: + w.write_all(rows) + + def write_footer(self) -> None: + pass + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("jsonl", "jsonl", False, True) + + +class JsonlListPUAEncodeWriter(JsonlWriter): + + def write_data(self, rows: Sequence[Any]) -> None: + # skip JsonlWriter when calling super + super(JsonlWriter, self).write_data(rows) + # encode types with PUA characters + with jsonlines.Writer(self._f, dumps=json_typed_dumps) as w: + # write all rows as one list which will require to write just one line + w.write_all([rows]) + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("puae-jsonl", "jsonl", False, True) + + +class InsertValuesWriter(DataWriter): + + def __init__(self, f: IO[Any]) -> None: + super().__init__(f) + self._chunks_written = 0 + self._headers_lookup: Dict[str, int] 
= None + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + assert self._chunks_written == 0 + headers = columns_schema.keys() + # dict lookup is always faster + self._headers_lookup = {v: i for i, v in enumerate(headers)} + # do not write INSERT INTO command, this must be added together with table name by the loader + self._f.write("INSERT INTO {}(") + self._f.write(",".join(map(escape_redshift_identifier, headers))) + self._f.write(")\nVALUES\n") + + def write_data(self, rows: Sequence[Any]) -> None: + super().write_data(rows) + + def stringify(v: Any) -> str: + if isinstance(v, bytes): + return f"from_hex('{v.hex()}')" + if isinstance(v, (datetime, date)): + return escape_redshift_literal(v.isoformat()) + else: + return str(v) + + def write_row(row: StrAny) -> None: + output = ["NULL"] * len(self._headers_lookup) + for n,v in row.items(): + output[self._headers_lookup[n]] = escape_redshift_literal(v) if isinstance(v, str) else stringify(v) + self._f.write("(") + self._f.write(",".join(output)) + self._f.write(")") + + # if next chunk add separator + if self._chunks_written > 0: + self._f.write(",\n") + + # write rows + for row in rows[:-1]: + write_row(row) + self._f.write(",\n") + + # write last row without separator so we can write footer eventually + write_row(rows[-1]) + self._chunks_written += 1 + + def write_footer(self) -> None: + assert self._chunks_written > 0 + self._f.write(";") + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec("insert_values", "insert_values", False, False) diff --git a/dlt/common/dataset_writers.py b/dlt/common/dataset_writers.py deleted file mode 100644 index 67e1ea130d..0000000000 --- a/dlt/common/dataset_writers.py +++ /dev/null @@ -1,67 +0,0 @@ -import re -import jsonlines -from datetime import date, datetime # noqa: I251 -from typing import Any, Iterable, Literal, Sequence, IO - -from dlt.common import json -from dlt.common.typing import StrAny - -TLoaderFileFormat = Literal["jsonl", "insert_values"] - -# use regex to escape characters in single pass -SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} -SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL) - - -def write_jsonl(f: IO[Any], rows: Sequence[Any]) -> None: - # use jsonl to write load files https://jsonlines.org/ - with jsonlines.Writer(f, dumps=json.dumps) as w: - w.write_all(rows) - - -def write_insert_values(f: IO[Any], rows: Sequence[StrAny], headers: Iterable[str]) -> None: - # dict lookup is always faster - headers_lookup = {v: i for i, v in enumerate(headers)} - # do not write INSERT INTO command, this must be added together with table name by the loader - f.write("INSERT INTO {}(") - f.write(",".join(map(escape_redshift_identifier, headers))) - f.write(")\nVALUES\n") - - def stringify(v: Any) -> str: - if isinstance(v, bytes): - return f"from_hex('{v.hex()}')" - if isinstance(v, (datetime, date)): - return escape_redshift_literal(v.isoformat()) - else: - return str(v) - - def write_row(row: StrAny) -> None: - output = ["NULL" for _ in range(len(headers_lookup))] - for n,v in row.items(): - output[headers_lookup[n]] = escape_redshift_literal(v) if isinstance(v, str) else stringify(v) - f.write("(") - f.write(",".join(output)) - f.write(")") - - for row in rows[:-1]: - write_row(row) - f.write(",\n") - - write_row(rows[-1]) - f.write(";") - - -def escape_redshift_literal(v: str) -> str: - # 
https://www.postgresql.org/docs/9.3/sql-syntax-lexical.html - # looks like this is the only thing we need to escape for Postgres > 9.1 - # redshift keeps \ as escape character which is pre 9 behavior - return "{}{}{}".format("'", SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'") - - -def escape_redshift_identifier(v: str) -> str: - return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"' - - -def escape_bigquery_identifier(v: str) -> str: - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical - return "`" + v.replace("\\", "\\\\").replace("`","\\`") + "`" diff --git a/dlt/common/destination.py b/dlt/common/destination.py new file mode 100644 index 0000000000..5489ebe5a4 --- /dev/null +++ b/dlt/common/destination.py @@ -0,0 +1,161 @@ +from abc import ABC, abstractmethod +from importlib import import_module +from types import ModuleType, TracebackType +from typing import ClassVar, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING, cast + +from dlt.common.schema import Schema +from dlt.common.schema.typing import TTableSchema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration, ContainerInjectableContext + + +# known loader file formats +# jsonl - new line separated json documents +# puae-jsonl - internal extract -> normalize format based on jsonl +# insert_values - insert SQL statements +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values"] + + +@configspec(init=True) +class DestinationCapabilitiesContext(ContainerInjectableContext): + """Injectable destination capabilities required for many Pipeline stages ie. normalize""" + preferred_loader_file_format: TLoaderFileFormat + supported_loader_file_formats: List[TLoaderFileFormat] + max_identifier_length: int + max_column_length: int + max_query_length: int + is_max_query_length_in_bytes: bool + max_text_data_type_length: int + is_max_text_data_type_length_in_bytes: bool + + # do not allow to create default value, destination caps must be always explicitly inserted into container + can_create_default: ClassVar[bool] = False + + +@configspec(init=True) +class DestinationClientConfiguration(BaseConfiguration): + destination_name: str = None # which destination to load data to + credentials: Optional[CredentialsConfiguration] + + if TYPE_CHECKING: + def __init__(self, destination_name: str = None, credentials: Optional[CredentialsConfiguration] = None) -> None: + ... + + +@configspec(init=True) +class DestinationClientDwhConfiguration(DestinationClientConfiguration): + dataset_name: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix + default_schema_name: Optional[str] = None # name of default schema to be used to name effective dataset to load data to + + if TYPE_CHECKING: + def __init__( + self, + destination_name: str = None, + credentials: Optional[CredentialsConfiguration] = None, + dataset_name: str = None, + default_schema_name: Optional[str] = None + ) -> None: + ... + + +TLoadJobStatus = Literal["running", "failed", "retry", "completed"] + + +class LoadJob: + """Represents a job that loads a single file + + Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". + Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state.
In terminal state, the file may not be present. + In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. + `exception` method is called to get error information in "failed" and "retry" states. + + The `__init__` method is responsible for putting the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` to + immediately transition job into "failed" or "retry" state respectively. + """ + def __init__(self, file_name: str) -> None: + """ + File name is also a job id (or job id is deterministically derived) so it must be globally unique + """ + self._file_name = file_name + + @abstractmethod + def status(self) -> TLoadJobStatus: + pass + + @abstractmethod + def file_name(self) -> str: + pass + + @abstractmethod + def exception(self) -> str: + pass + + +class JobClientBase(ABC): + def __init__(self, schema: Schema, config: DestinationClientConfiguration) -> None: + self.schema = schema + self.config = config + + @abstractmethod + def initialize_storage(self, wipe_data: bool = False) -> None: + pass + + @abstractmethod + def update_storage_schema(self) -> None: + pass + + @abstractmethod + def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: + pass + + @abstractmethod + def restore_file_load(self, file_path: str) -> LoadJob: + pass + + @abstractmethod + def complete_load(self, load_id: str) -> None: + pass + + @abstractmethod + def __enter__(self) -> "JobClientBase": + pass + + @abstractmethod + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + pass + + @classmethod + @abstractmethod + def capabilities(cls) -> DestinationCapabilitiesContext: + pass + + +class DestinationReference(Protocol): + __name__: str + + def capabilities(self) -> DestinationCapabilitiesContext: + ... + + def client(self, schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> "JobClientBase": + ... + + def spec(self) -> Type[DestinationClientConfiguration]: + ... + + @staticmethod + def from_name(destination: Union[None, str, "DestinationReference"]) -> "DestinationReference": + if destination is None: + return None + + # if destination is a str, get destination reference by dynamically importing module + if isinstance(destination, str): + if "."
in destination: + # this is full module name + return cast(DestinationReference, import_module(destination)) + else: + # from known location + return cast(DestinationReference, import_module(f"dlt.load.{destination}")) + + return destination diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index 1f88dac73b..0d88d8d5f4 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -51,21 +51,18 @@ class TerminalException(Exception): """ Marks an exception that cannot be recovered from, should be mixed in into concrete exception class """ - pass class TransientException(Exception): """ Marks an exception in operation that can be retried, should be mixed in into concrete exception class """ - pass class TerminalValueError(ValueError, TerminalException): """ ValueError that is unrecoverable """ - pass class TimeRangeExhaustedException(DltException): @@ -79,8 +76,16 @@ def __init__(self, start_ts: float, end_ts: float) -> None: class DictValidationException(DltException): - def __init__(self, msg: str, path: str, field: str = None, value: Any = None): + def __init__(self, msg: str, path: str, field: str = None, value: Any = None) -> None: self.path = path self.field = field self.value = value super().__init__(msg) + + +class ArgumentsOverloadException(DltException): + def __init__(self, msg: str, func_name: str, *args: str) -> None: + self.func_name = func_name + msg = f"Arguments combination not allowed when calling function {func_name}: {msg}" + msg = "\n".join((msg, *args)) + super().__init__(msg) diff --git a/dlt/common/json.py b/dlt/common/json.py index cb578c0357..6f2d3cf5d7 100644 --- a/dlt/common/json.py +++ b/dlt/common/json.py @@ -2,7 +2,7 @@ import pendulum from datetime import date, datetime # noqa: I251 from functools import partial -from typing import Any, Callable, Union +from typing import Any, Callable, List, Union from uuid import UUID from hexbytes import HexBytes import simplejson @@ -43,22 +43,22 @@ def custom_encode(obj: Any) -> str: # use PUA range to encode additional types -_DECIMAL = u'\uF026' -_DATETIME = u'\uF027' -_DATE = u'\uF028' -_UUIDT = u'\uF029' -_HEXBYTES = u'\uF02A' -_B64BYTES = u'\uF02B' -_WEI = u'\uF02C' +_DECIMAL = '\uF026' +_DATETIME = '\uF027' +_DATE = '\uF028' +_UUIDT = '\uF029' +_HEXBYTES = '\uF02A' +_B64BYTES = '\uF02B' +_WEI = '\uF02C' -DECODERS = [ - lambda s: Decimal(s), - lambda s: pendulum.parse(s), +DECODERS: List[Callable[[Any], Any]] = [ + Decimal, + pendulum.parse, lambda s: pendulum.parse(s).date(), # type: ignore - lambda s: UUID(s), - lambda s: HexBytes(s), - lambda s: base64.b64decode(s), - lambda s: Wei(s) + UUID, + HexBytes, + base64.b64decode, + Wei ] diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 9de24cb809..9a2094a871 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -2,20 +2,23 @@ import json_logging import traceback import sentry_sdk +from importlib.metadata import version as pkg_version, PackageNotFoundError from sentry_sdk.transport import HttpTransport from sentry_sdk.integrations.logging import LoggingIntegration from logging import LogRecord, Logger -from typing import Any, Type, Protocol +from typing import Any, Protocol from dlt.common.json import json from dlt.common.typing import DictStrAny, StrStr -from dlt.common.configuration import RunConfiguration +from dlt.common.configuration.specs import RunConfiguration from dlt.common.utils import filter_env_vars -from dlt._version import common_version as __version__ -DLT_LOGGER_NAME = "sv-dlt" +from dlt import __version__ 
+ +DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None + def _add_logging_level(level_name: str, level: int, method_name:str = None) -> None: """ Comprehensively adds a new logging level to the `logging` module and the @@ -36,11 +39,11 @@ def _add_logging_level(level_name: str, level: int, method_name:str = None) -> N method_name = level_name.lower() if hasattr(logging, level_name): - raise AttributeError('{} already defined in logging module'.format(level_name)) + raise AttributeError('{} already defined in logging module'.format(level_name)) if hasattr(logging, method_name): - raise AttributeError('{} already defined in logging module'.format(method_name)) + raise AttributeError('{} already defined in logging module'.format(method_name)) if hasattr(logging.getLoggerClass(), method_name): - raise AttributeError('{} already defined in logger class'.format(method_name)) + raise AttributeError('{} already defined in logger class'.format(method_name)) # This method was inspired by the answers to Stack Overflow post # http://stackoverflow.com/q/2183233/2988730, especially @@ -126,11 +129,13 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: return wrapper -def _extract_version_info(config: Type[RunConfiguration]) -> StrStr: - version_info = {"version": __version__, "component_name": config.PIPELINE_NAME} - version = getattr(config, "_VERSION", None) - if version: - version_info["component_version"] = version +def _extract_version_info(config: RunConfiguration) -> StrStr: + try: + version = pkg_version("python-dlt") + except PackageNotFoundError: + # if there's no package context, take the version from the code + version = __version__ + version_info = {"dlt_version": version, "pipeline_name": config.pipeline_name} # extract envs with build info version_info.update(filter_env_vars(["COMMIT_SHA", "IMAGE_VERSION"])) return version_info @@ -150,8 +155,8 @@ def _get_pool_options(self, *a: Any, **kw: Any) -> DictStrAny: return rv -def _get_sentry_log_level(C: Type[RunConfiguration]) -> LoggingIntegration: - log_level = logging._nameToLevel[C.LOG_LEVEL] +def _get_sentry_log_level(C: RunConfiguration) -> LoggingIntegration: + log_level = logging._nameToLevel[C.log_level] event_level = logging.WARNING if log_level <= logging.WARNING else log_level return LoggingIntegration( level=logging.INFO, # Capture info and above as breadcrumbs @@ -159,14 +164,14 @@ def _get_sentry_log_level(C: Type[RunConfiguration]) -> LoggingIntegration: ) -def _init_sentry(C: Type[RunConfiguration], version: StrStr) -> None: - sys_ver = version["version"] +def _init_sentry(C: RunConfiguration, version: StrStr) -> None: + sys_ver = version["dlt_version"] release = sys_ver + "_" + version.get("commit_sha", "") - _SentryHttpTransport.timeout = C.REQUEST_TIMEOUT[0] + _SentryHttpTransport.timeout = C.request_timeout[0] # TODO: ignore certain loggers ie. 
dbt loggers # https://docs.sentry.io/platforms/python/guides/logging/ sentry_sdk.init( - C.SENTRY_DSN, + C.sentry_dsn, integrations=[_get_sentry_log_level(C)], release=release, transport=_SentryHttpTransport @@ -180,17 +185,17 @@ def _init_sentry(C: Type[RunConfiguration], version: StrStr) -> None: sentry_sdk.set_tag(k, v) -def init_telemetry(config: Type[RunConfiguration]) -> None: - if config.PROMETHEUS_PORT: +def init_telemetry(config: RunConfiguration) -> None: + if config.prometheus_port: from prometheus_client import start_http_server, Info - logging.info(f"Starting prometheus server port {config.PROMETHEUS_PORT}") - start_http_server(config.PROMETHEUS_PORT) + logging.info(f"Starting prometheus server port {config.prometheus_port}") + start_http_server(config.prometheus_port) # collect info Info("runs_component_name", "Name of the executing component").info(_extract_version_info(config)) -def init_logging_from_config(C: Type[RunConfiguration]) -> None: +def init_logging_from_config(C: RunConfiguration) -> None: global LOGGER # add HEALTH and METRICS log levels @@ -201,14 +206,24 @@ def init_logging_from_config(C: Type[RunConfiguration]) -> None: version = _extract_version_info(C) LOGGER = _init_logging( DLT_LOGGER_NAME, - C.LOG_LEVEL, - C.LOG_FORMAT, - C.PIPELINE_NAME, + C.log_level, + C.log_format, + C.pipeline_name, version) - if C.SENTRY_DSN: + if C.sentry_dsn: _init_sentry(C, version) +def is_logging() -> bool: + return LOGGER is not None + + +def log_level() -> str: + if not LOGGER: + raise RuntimeError("Logger not initialized") + return logging.getLevelName(LOGGER.level) # type: ignore + + def is_json_logging(log_format: str) -> bool: return log_format == "JSON" diff --git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index 0b26371078..81564972c9 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -1,7 +1,7 @@ -from typing import Iterator, Tuple, Callable, TYPE_CHECKING +from typing import Any, Iterator, Tuple, Callable, TYPE_CHECKING -from dlt.common.typing import TDataItem, StrAny +from dlt.common.typing import DictStrAny, TDataItem, StrAny if TYPE_CHECKING: from dlt.common.schema import Schema @@ -11,4 +11,9 @@ TNormalizedRowIterator = Iterator[Tuple[Tuple[str, str], StrAny]] # normalization function signature -TNormalizeJSONFunc = Callable[["Schema", TDataItem, str], TNormalizedRowIterator] +TNormalizeJSONFunc = Callable[["Schema", TDataItem, str, str], TNormalizedRowIterator] + + +def wrap_in_dict(item: Any) -> DictStrAny: + """Wraps `item` that is not a dictionary into dictionary that can be json normalized""" + return {"value": item} diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 5b1b7c99f0..1d87e48a11 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -5,21 +5,21 @@ from dlt.common.schema.typing import TColumnSchema, TColumnName, TSimpleRegex from dlt.common.schema.utils import column_name_validator from dlt.common.utils import uniq_id, digest128 -from dlt.common.normalizers.json import TNormalizedRowIterator -from dlt.common.sources import DLT_METADATA_FIELD, TEventDLTMeta, get_table_name +from dlt.common.normalizers.json import TNormalizedRowIterator, wrap_in_dict +# from dlt.common.source import TEventDLTMeta from dlt.common.validation import validate_dict -class TEventRow(TypedDict, total=False): +class TDataItemRow(TypedDict, total=False): _dlt_id: str # unique id of 
current row -class TEventRowRoot(TEventRow, total=False): +class TDataItemRowRoot(TDataItemRow, total=False): _dlt_load_id: str # load id to identify records loaded together that ie. need to be processed - _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer + # _dlt_meta: TEventDLTMeta # stores metadata, should never be sent to the normalizer -class TEventRowChild(TEventRow, total=False): +class TDataItemRowChild(TDataItemRow, total=False): _dlt_root_id: str # unique id of top level parent _dlt_parent_id: str # unique id of parent row _dlt_list_idx: int # position in the list of rows @@ -40,7 +40,7 @@ class JSONNormalizerConfig(TypedDict, total=True): # for those paths the complex nested objects should be left in place def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: int) -> bool: # turn everything at the recursion level into complex type - max_nesting = schema._normalizers_config["json"].get("config", {}).get("max_nesting", 1000) + max_nesting = (schema._normalizers_config["json"].get("config") or {}).get("max_nesting", 1000) assert _r_lvl <= max_nesting if _r_lvl == max_nesting: return True @@ -48,7 +48,7 @@ def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: i column: TColumnSchema = None table = schema._schema_tables.get(table_name) if table: - column = table["columns"].get(field_name, None) + column = table["columns"].get(field_name) if column is None: data_type = schema.get_preferred_type(field_name) else: @@ -56,7 +56,7 @@ def _is_complex_type(schema: Schema, table_name: str, field_name: str, _r_lvl: i return data_type == "complex" -def _flatten(schema: Schema, table: str, dict_row: TEventRow, _r_lvl: int) -> Tuple[TEventRow, Dict[str, Sequence[Any]]]: +def _flatten(schema: Schema, table: str, dict_row: TDataItemRow, _r_lvl: int) -> Tuple[TDataItemRow, Dict[str, Sequence[Any]]]: out_rec_row: DictStrAny = {} out_rec_list: Dict[str, Sequence[Any]] = {} @@ -67,6 +67,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, parent_name: Optional[str]) - # for lists and dicts we must check if type is possibly complex if isinstance(v, (dict, list)): if not _is_complex_type(schema, table, child_name, __r_lvl): + # TODO: if schema contains table {table}__{child_name} then convert v into single element list if isinstance(v, dict): # flatten the dict more norm_row_dicts(v, __r_lvl + 1, parent_name=child_name) @@ -81,7 +82,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, parent_name: Optional[str]) - out_rec_row[child_name] = v norm_row_dicts(dict_row, _r_lvl, None) - return cast(TEventRow, out_rec_row), out_rec_list + return cast(TDataItemRow, out_rec_row), out_rec_list def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> str: @@ -90,26 +91,30 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> return digest128(f"{parent_row_id}_{child_table}_{list_idx}") -def _add_linking(row: TEventRowChild, extend: DictStrAny, parent_row_id: str, list_idx: int) -> TEventRowChild: +def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: row["_dlt_parent_id"] = parent_row_id row["_dlt_list_idx"] = list_idx return row +def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: + row.update(extend) # type: ignore + + def _get_content_hash(schema: Schema, table: str, row: StrAny) -> str: return digest128(uniq_id()) -def _get_propagated_values(schema: Schema, table: str, row: TEventRow, is_top_level: bool) -> StrAny: - 
config: JSONNormalizerConfigPropagation = schema._normalizers_config["json"].get("config", {}).get("propagation", None) +def _get_propagated_values(schema: Schema, table: str, row: TDataItemRow, is_top_level: bool) -> StrAny: + config: JSONNormalizerConfigPropagation = (schema._normalizers_config["json"].get("config") or {}).get("propagation", None) extend: DictStrAny = {} if config: # mapping(k:v): propagate property with name "k" as property with name "v" in child table mappings: DictStrStr = {} if is_top_level: - mappings.update(config.get("root", {})) - if table in config.get("tables", {}): + mappings.update(config.get("root") or {}) + if table in (config.get("tables") or {}): mappings.update(config["tables"][table]) # look for keys and create propagation as values for prop_from, prop_as in mappings.items(): @@ -119,10 +124,6 @@ def _get_propagated_values(schema: Schema, table: str, row: TEventRow, is_top_le return extend -def _extend_row(extend: DictStrAny, row: TEventRow) -> None: - row.update(extend) # type: ignore - - # generate child tables only for lists def _normalize_list( schema: Schema, @@ -134,7 +135,7 @@ def _normalize_list( _r_lvl: int = 0 ) -> TNormalizedRowIterator: - v: TEventRowChild = None + v: TDataItemRowChild = None for idx, v in enumerate(seq): # yield child table row if isinstance(v, dict): @@ -146,14 +147,16 @@ def _normalize_list( else: # list of simple types child_row_hash = _get_child_row_hash(parent_row_id, table, idx) - e = _add_linking({"value": v, "_dlt_id": child_row_hash}, extend, parent_row_id, idx) + wrap_v = wrap_in_dict(v) + wrap_v["_dlt_id"] = child_row_hash + e = _link_row(wrap_v, parent_row_id, idx) _extend_row(extend, e) yield (table, parent_table), e def _normalize_row( schema: Schema, - dict_row: TEventRow, + dict_row: TDataItemRow, extend: DictStrAny, table: str, parent_table: Optional[str] = None, @@ -174,12 +177,12 @@ def _normalize_row( primary_key = schema.filter_row_with_hint(table, "primary_key", flattened_row) if primary_key: # create row id from primary key - row_id = digest128("_".join(map(lambda v: str(v), primary_key.values()))) + row_id = digest128("_".join(map(str, primary_key.values()))) elif not is_top_level: # child table row deterministic hash row_id = _get_child_row_hash(parent_row_id, table, pos) # link to parent table - _add_linking(cast(TEventRowChild, flattened_row), extend, parent_row_id, pos) + _link_row(cast(TDataItemRowChild, flattened_row), parent_row_id, pos) else: # create hash based on the content of the row row_id = _get_content_hash(schema, table, flattened_row) @@ -198,11 +201,11 @@ def _normalize_row( def extend_schema(schema: Schema) -> None: # validate config - config = schema._normalizers_config["json"].get("config", {}) + config = schema._normalizers_config["json"].get("config") or {} validate_dict(JSONNormalizerConfig, config, "./normalizers/json/config", validator_f=column_name_validator(schema.normalize_column_name)) # quick check to see if hints are applied - default_hints = schema.settings.get("default_hints", {}) + default_hints = schema.settings.get("default_hints") or {} if "not_null" in default_hints and "^_dlt_id$" in default_hints["not_null"]: return # add hints @@ -218,14 +221,12 @@ def extend_schema(schema: Schema) -> None: ) -def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator: +def normalize_data_item(schema: Schema, item: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator: + # wrap items that are not dictionaries in 
dictionary, otherwise they cannot be processed by the JSON normalizer + if not isinstance(item, dict): + item = wrap_in_dict(item) # we will extend event with all the fields necessary to load it as root row - event = cast(TEventRowRoot, source_event) + row = cast(TDataItemRowRoot, item) # identify load id if loaded data must be processed after loading incrementally - event["_dlt_load_id"] = load_id - # find table name - table_name = schema.normalize_table_name(get_table_name(event) or schema.name) - # drop dlt metadata before normalizing - event.pop(DLT_METADATA_FIELD, None) # type: ignore - # use event type or schema name as table name, request _dlt_root_id propagation - yield from _normalize_row(schema, cast(TEventRowChild, event), {}, table_name) + row["_dlt_load_id"] = load_id + yield from _normalize_row(schema, cast(TDataItemRowChild, row), {}, schema.normalize_table_name(table_name)) diff --git a/dlt/common/normalizers/names/snake_case.py b/dlt/common/normalizers/names/snake_case.py index 27d1629966..efeb00c0fb 100644 --- a/dlt/common/normalizers/names/snake_case.py +++ b/dlt/common/normalizers/names/snake_case.py @@ -1,5 +1,6 @@ import re from typing import Any, Sequence +from functools import lru_cache RE_UNDERSCORES = re.compile("_+") @@ -15,6 +16,7 @@ # fix a name so it's acceptable as database table name +@lru_cache(maxsize=None) def normalize_table_name(name: str) -> str: if not name: raise ValueError(name) @@ -34,16 +36,22 @@ def camel_to_snake(name: str) -> str: # fix a name so it's an acceptable name for a database column +@lru_cache(maxsize=None) def normalize_column_name(name: str) -> str: # replace consecutive underscores with single one to prevent name clashes with PATH_SEPARATOR return RE_UNDERSCORES.sub("_", normalize_table_name(name)) +# fix a name so it is acceptable as schema name +def normalize_schema_name(name: str) -> str: + return normalize_column_name(name) + + # build full db dataset (dataset) name out of (normalized) default dataset and schema name -def normalize_make_dataset_name(default_dataset: str, default_schema_name: str, schema_name: str) -> str: +def normalize_make_dataset_name(dataset_name: str, default_schema_name: str, schema_name: str) -> str: if schema_name is None: raise ValueError("schema_name is None") - name = normalize_column_name(default_dataset) + name = normalize_column_name(dataset_name) if default_schema_name is None or schema_name != default_schema_name: name += "_" + schema_name diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py new file mode 100644 index 0000000000..13781cc23a --- /dev/null +++ b/dlt/common/pipeline.py @@ -0,0 +1,69 @@ +import os +import tempfile +from typing import Any, Callable, ClassVar, Protocol, Sequence + +from dlt.common.configuration.container import ContainerInjectableContext +from dlt.common.configuration import configspec +from dlt.common.destination import DestinationReference +from dlt.common.schema import Schema +from dlt.common.schema.typing import TColumnSchema, TWriteDisposition + + +class SupportsPipeline(Protocol): + """A protocol with core pipeline operations that lets high level abstractions ie. sources to access pipeline methods and properties""" + def run( + self, + source: Any = None, + destination: DestinationReference = None, + dataset_name: str = None, + table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + schema: Schema = None + ) -> Any: + ... 
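# [editor's sketch, not part of the patch] SupportsPipeline above is a typing.Protocol,
# so any object exposing a matching run() satisfies it structurally and sources can call
# back into the active pipeline without importing the concrete Pipeline class. A minimal,
# self-contained illustration of that idea; SupportsRun, EchoPipeline and run_source are
# hypothetical names used only here.
from typing import Any, Protocol

class SupportsRun(Protocol):
    def run(self, source: Any = None, table_name: str = None) -> Any:
        ...

class EchoPipeline:
    def run(self, source: Any = None, table_name: str = None) -> Any:
        # a stand-in pipeline: just report what it would load
        return f"would load {source!r} into table {table_name!r}"

def run_source(pipeline: SupportsRun, data: Any) -> Any:
    # only the protocol surface is used, no concrete pipeline type is required
    return pipeline.run(source=data, table_name="items")

print(run_source(EchoPipeline(), [1, 2, 3]))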
+ + +@configspec(init=True) +class PipelineContext(ContainerInjectableContext): + # TODO: declare unresolvable generic types that will be allowed by configpec + _deferred_pipeline: Any + _pipeline: Any + + can_create_default: ClassVar[bool] = False + + def pipeline(self) -> SupportsPipeline: + """Creates or returns exiting pipeline""" + if not self._pipeline: + # delayed pipeline creation + self._pipeline = self._deferred_pipeline() + return self._pipeline # type: ignore + + def activate(self, pipeline: SupportsPipeline) -> None: + self._pipeline = pipeline + + def is_activated(self) -> bool: + return self._pipeline is not None + + def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline]) -> None: + """Initialize the context with a function returning the Pipeline object to allow creation on first use""" + self._deferred_pipeline = deferred_pipeline + + +def get_default_working_dir() -> str: + """ Gets default working dir of the pipeline, which may be + 1. in user home directory ~/.dlt/pipelines/ + 2. if current user is root in /var/dlt/pipelines + 3. if current user does not have a home directory in /tmp/dlt/pipelines + """ + if os.geteuid() == 0: + # we are root so use standard /var + return os.path.join("/var", "dlt", "pipelines") + + home = os.path.expanduser("~") + if home is None: + # no home dir - use temp + return os.path.join(tempfile.gettempdir(), "dlt", "pipelines") + else: + # if home directory is available use ~/.dlt/pipelines + return os.path.join(home, ".dlt", "pipelines") diff --git a/dlt/common/runners/init.py b/dlt/common/runners/init.py index 508702f2bd..41c536bf82 100644 --- a/dlt/common/runners/init.py +++ b/dlt/common/runners/init.py @@ -1,16 +1,15 @@ import threading -from typing import Type from dlt.common import logger -from dlt.common.configuration.run_configuration import RunConfiguration from dlt.common.logger import init_logging_from_config, init_telemetry from dlt.common.signals import register_signals +from dlt.common.configuration.specs import RunConfiguration # signals and telemetry should be initialized only once _INITIALIZED = False -def initialize_runner(C: Type[RunConfiguration]) -> None: +def initialize_runner(C: RunConfiguration) -> None: global _INITIALIZED # initialize or re-initialize logging with new settings diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index c0e159a1af..831dad5cdc 100644 --- a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -1,16 +1,14 @@ -import argparse import multiprocessing from prometheus_client import Counter, Gauge, Summary, CollectorRegistry, REGISTRY -from typing import Callable, Dict, NamedTuple, Optional, Type, TypeVar, Union, cast +from typing import Callable, Dict, Union, cast from multiprocessing.pool import ThreadPool, Pool from dlt.common import logger, signals from dlt.common.runners.runnable import Runnable, TPool from dlt.common.time import sleep from dlt.common.telemetry import TRunHealth, TRunMetrics, get_logging_extras, get_metrics_from_prometheus -from dlt.common.utils import str2bool from dlt.common.exceptions import SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException -from dlt.common.configuration import PoolRunnerConfiguration +from dlt.common.configuration.specs import PoolRunnerConfiguration HEALTH_PROPS_GAUGES: Dict[str, Union[Counter, Gauge]] = None @@ -41,23 +39,23 @@ def update_gauges() -> TRunHealth: return get_metrics_from_prometheus(HEALTH_PROPS_GAUGES.values()) # type: ignore 
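# [editor's sketch, not part of the patch] run_pool in the next hunk now receives a
# configuration *instance* with lowercase fields (pool_type, workers, ...) instead of a
# class with uppercase constants. A rough stand-alone sketch of the pool selection it
# performs; PoolConfig and make_pool are illustrative stand-ins, not the real
# PoolRunnerConfiguration spec.
from dataclasses import dataclass
from multiprocessing.pool import Pool, ThreadPool
from typing import Optional

@dataclass
class PoolConfig:
    pool_type: Optional[str] = "thread"   # "process", "thread" or None
    workers: Optional[int] = 4

def make_pool(config: PoolConfig):
    if config.pool_type == "process":
        # the real runner additionally requires the "fork" start method here
        return Pool(processes=config.workers)
    if config.pool_type == "thread":
        return ThreadPool(processes=config.workers)
    return None

pool = make_pool(PoolConfig())
pool.terminate()  # the hunk below also switches shutdown from close()/join() to terminate()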
-def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]]) -> int: +def run_pool(C: PoolRunnerConfiguration, run_f: Union[Runnable[TPool], Callable[[TPool], TRunMetrics]]) -> int: # create health gauges if not HEALTH_PROPS_GAUGES: create_gauges(REGISTRY) # start pool pool: Pool = None - if C.POOL_TYPE == "process": + if C.pool_type == "process": # our pool implementation do not work on spawn if multiprocessing.get_start_method() != "fork": raise UnsupportedProcessStartMethodException(multiprocessing.get_start_method()) - pool = Pool(processes=C.WORKERS) - elif C.POOL_TYPE == "thread": - pool = ThreadPool(processes=C.WORKERS) + pool = Pool(processes=C.workers) + elif C.pool_type == "thread": + pool = ThreadPool(processes=C.workers) else: pool = None - logger.info(f"Created {C.POOL_TYPE} pool with {C.WORKERS or 'default no.'} workers") + logger.info(f"Created {C.pool_type} pool with {C.workers or 'default no.'} workers") # track local stats runs_count = 0 runs_not_idle_count = 0 @@ -90,7 +88,7 @@ def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal global LAST_RUN_EXCEPTION LAST_RUN_EXCEPTION = exc # re-raise if EXIT_ON_EXCEPTION is requested - if C.EXIT_ON_EXCEPTION: + if C.exit_on_exception: raise finally: if run_metrics: @@ -103,22 +101,22 @@ def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal # single run may be forced but at least wait_runs must pass # and was all the time idle or (was not idle but now pending is 0) - if C.IS_SINGLE_RUN and (runs_count >= C.WAIT_RUNS and (runs_not_idle_count == 0 or run_metrics.pending_items == 0)): + if C.is_single_run and (runs_count >= C.wait_runs and (runs_not_idle_count == 0 or run_metrics.pending_items == 0)): logger.info("Stopping runner due to single run override") return 0 if run_metrics.has_failed: - sleep(C.RUN_SLEEP_WHEN_FAILED) + sleep(C.run_sleep_when_failed) elif run_metrics.pending_items == 0: # nothing is pending so we can sleep longer - sleep(C.RUN_SLEEP_IDLE) + sleep(C.run_sleep_idle) else: # more items are pending, sleep (typically) shorter - sleep(C.RUN_SLEEP) + sleep(C.run_sleep) # this allows to recycle long living process that get their memory fragmented # exit after runners sleeps so we keep the running period - if runs_count == C.STOP_AFTER_RUNS: + if runs_count == C.stop_after_runs: logger.warning(f"Stopping runner due to max runs {runs_count} exceeded") return 0 except SignalReceivedException as sigex: @@ -131,9 +129,13 @@ def run_pool(C: Type[PoolRunnerConfiguration], run_f: Union[Runnable[TPool], Cal finally: if pool: logger.info("Closing processing pool") - pool.close() - pool.join() + # terminate pool and do not join + pool.terminate() + # in very rare cases process hangs here, even with starmap terminating earlier + # pool.close() + # pool.join() pool = None + logger.info("Closing processing pool closed") def _update_metrics(run_metrics: TRunMetrics) -> TRunHealth: diff --git a/dlt/common/runners/runnable.py b/dlt/common/runners/runnable.py index ad73cce26d..c3fce997d6 100644 --- a/dlt/common/runners/runnable.py +++ b/dlt/common/runners/runnable.py @@ -4,8 +4,8 @@ from multiprocessing.pool import Pool from weakref import WeakValueDictionary -from dlt.common.telemetry import TRunMetrics from dlt.common.typing import TFun +from dlt.common.telemetry import TRunMetrics TPool = TypeVar("TPool", bound=Pool) @@ -56,3 +56,42 @@ def _wrap(rid: Union[int, Runnable[TPool]], *args: Any, **kwargs: Any) -> Any: return f(rid, 
*args, **kwargs) return _wrap # type: ignore + + +# def configuredworker(f: TFun) -> TFun: +# """Decorator for a process/thread pool worker function facilitates passing bound configuration type across the process boundary. It requires the first method +# of the worker function to be annotated with type derived from Type[BaseConfiguration] and the worker function to be called (typically by the Pool class) with a +# configuration values serialized to dict (via `as_dict` method). The decorator will synthesize a new derived type and apply the serialized value, mimicking the +# original type to be transferred across the process boundary. + +# Args: +# f (TFun): worker function to be decorated + +# Raises: +# ValueError: raised when worker function signature does not contain required parameters or/and annotations + + +# Returns: +# TFun: wrapped worker function +# """ +# @wraps(f) +# def _wrap(config: Union[StrAny, Type[BaseConfiguration]], *args: Any, **kwargs: Any) -> Any: +# if isinstance(config, Mapping): +# # worker process may run in separate process started with spawn and should not share any state with the parent process ie. global variables like config +# # first function parameter should be of Type[BaseConfiguration] +# sig = inspect.signature(f) +# try: +# first_param: inspect.Parameter = next(iter(sig.parameters.values())) +# T = get_args(first_param.annotation)[0] +# if not issubclass(T, BaseConfiguration): +# raise ValueError(T) +# except Exception: +# raise ValueError(f"First parameter of wrapped worker method {f.__name__} must by annotated as Type[BaseConfiguration]") +# CONFIG = type(f.__name__ + uniq_id(), (T, ), {}) +# CONFIG.apply_dict(config) # type: ignore +# config = CONFIG + +# return f(config, *args, **kwargs) + +# return _wrap # type: ignore + diff --git a/dlt/common/schema/__init__.py b/dlt/common/schema/__init__.py index b3a95af283..80f008f432 100644 --- a/dlt/common/schema/__init__.py +++ b/dlt/common/schema/__init__.py @@ -1,4 +1,4 @@ -from dlt.common.schema.typing import TSchemaUpdate, TStoredSchema, TTableSchemaColumns, TDataType, THintType, TColumnSchema, TColumnSchemaBase # noqa: F401 +from dlt.common.schema.typing import TSchemaUpdate, TStoredSchema, TTableSchemaColumns, TDataType, TColumnHint, TColumnSchema, TColumnSchemaBase # noqa: F401 from dlt.common.schema.typing import COLUMN_HINTS # noqa: F401 from dlt.common.schema.schema import Schema # noqa: F401 -from dlt.common.schema.utils import normalize_schema_name, add_missing_hints, verify_schema_hash # noqa: F401 +from dlt.common.schema.utils import add_missing_hints, verify_schema_hash # noqa: F401 diff --git a/dlt/common/schema/detections.py b/dlt/common/schema/detections.py index 697251de22..49acabf97b 100644 --- a/dlt/common/schema/detections.py +++ b/dlt/common/schema/detections.py @@ -27,6 +27,7 @@ def is_iso_timestamp(t: Type[Any], v: Any) -> Optional[TDataType]: return None # strict autodetection of iso timestamps try: + # TODO: use same functions as in coercions dt = pendulum.parse(v, strict=True, exact=True) if isinstance(dt, datetime.datetime): return "timestamp" diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 57e43b81cd..cf335b5763 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -11,7 +11,7 @@ class SchemaException(DltException): class InvalidSchemaName(SchemaException): def __init__(self, name: str, normalized_name: str) -> None: self.name = name - super().__init__(f"{name} is invalid schema name. 
Only lowercase letters are allowed. Try {normalized_name} instead") + super().__init__(f"{name} is invalid schema name. Try {normalized_name} instead") class CannotCoerceColumnException(SchemaException): diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index db7d14f2a6..2a15ca1a74 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -9,10 +9,10 @@ from dlt.common.normalizers.json import TNormalizeJSONFunc from dlt.common.schema.typing import (TNormalizersConfig, TPartialTableSchema, TSchemaSettings, TSimpleRegex, TStoredSchema, TSchemaTables, TTableSchema, TTableSchemaColumns, TColumnSchema, TColumnProp, TDataType, - THintType, TWriteDisposition) + TColumnHint, TWriteDisposition) from dlt.common.schema import utils from dlt.common.schema.exceptions import (CannotCoerceColumnException, CannotCoerceNullException, InvalidSchemaName, - ParentTableNotFoundException, SchemaCorruptedException, TablePropertiesConflictException) + ParentTableNotFoundException, SchemaCorruptedException) from dlt.common.validation import validate_dict @@ -24,9 +24,6 @@ class Schema: ENGINE_VERSION = 4 def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: - # verify schema name - if name != utils.normalize_schema_name(name): - raise InvalidSchemaName(name, utils.normalize_schema_name(name)) self._schema_tables: TSchemaTables = {} self._schema_name: str = name self._stored_version = 1 # version at load/creation time @@ -38,7 +35,7 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: # list of preferred types: map regex on columns into types self._compiled_preferred_types: List[Tuple[REPattern, TDataType]] = [] # compiled default hints - self._compiled_hints: Dict[THintType, Sequence[REPattern]] = {} + self._compiled_hints: Dict[TColumnHint, Sequence[REPattern]] = {} # compiled exclude filters per table self._compiled_excludes: Dict[str, Sequence[REPattern]] = {} # compiled include filters per table @@ -62,6 +59,8 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: self._add_standard_hints() # configure normalizers, including custom config if present self._configure_normalizers() + # verify schema name after configuring normalizers + self._verify_schema_name(name) # compile all known regexes self._compile_regexes() # set initial version hash @@ -78,10 +77,13 @@ def from_dict(cls, d: DictStrAny) -> "Schema": # bump version if modified utils.bump_version_if_modified(stored_schema) + return cls.from_stored_schema(stored_schema) + @classmethod + def from_stored_schema(cls, stored_schema: TStoredSchema) -> "Schema": # create new instance from dict self: Schema = cls(stored_schema["name"], normalizers=stored_schema.get("normalizers", None)) - self._schema_tables = stored_schema.get("tables", {}) + self._schema_tables = stored_schema.get("tables") or {} if Schema.VERSION_TABLE_NAME not in self._schema_tables: raise SchemaCorruptedException(f"Schema must contain table {Schema.VERSION_TABLE_NAME}") if Schema.LOADS_TABLE_NAME not in self._schema_tables: @@ -89,7 +91,7 @@ def from_dict(cls, d: DictStrAny) -> "Schema": self._stored_version = stored_schema["version"] self._stored_version_hash = stored_schema["version_hash"] self._imported_version_hash = stored_schema.get("imported_version_hash") - self._settings = stored_schema.get("settings", {}) + self._settings = stored_schema.get("settings") or {} # compile regexes self._compile_regexes() @@ -140,7 +142,7 @@ def _exclude(path: str, excludes: 
Sequence[REPattern], includes: Sequence[REPatt excludes = self._compiled_excludes.get(c_t) # only if there's possibility to exclude, continue if excludes: - includes = self._compiled_includes.get(c_t, []) + includes = self._compiled_includes.get(c_t) or [] for field_name in list(row.keys()): path = self.normalize_make_path(*branch[i:], field_name) if _exclude(path, excludes, includes): @@ -151,12 +153,14 @@ def _exclude(path: str, excludes: Sequence[REPattern], includes: Sequence[REPatt break return row - def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[StrAny, TPartialTableSchema]: + def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[DictStrAny, TPartialTableSchema]: # get existing or create a new table - table = self._schema_tables.get(table_name, utils.new_table(table_name, parent_table)) + updated_table_partial: TPartialTableSchema = None + table = self._schema_tables.get(table_name) + if not table: + table = utils.new_table(table_name, parent_table) table_columns = table["columns"] - partial_table: TPartialTableSchema = None new_row: DictStrAny = {} for col_name, v in row.items(): # skip None values, we should infer the types later @@ -167,12 +171,13 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[S new_col_name, new_col_def, new_v = self._coerce_non_null_value(table_columns, table_name, col_name, v) new_row[new_col_name] = new_v if new_col_def: - if not partial_table: - partial_table = copy(table) - partial_table["columns"] = {} - partial_table["columns"][new_col_name] = new_col_def + if not updated_table_partial: + # create partial table with only the new columns + updated_table_partial = copy(table) + updated_table_partial["columns"] = {} + updated_table_partial["columns"][new_col_name] = new_col_def - return new_row, partial_table + return new_row, updated_table_partial def update_schema(self, partial_table: TPartialTableSchema) -> None: table_name = partial_table["name"] @@ -189,27 +194,8 @@ def update_schema(self, partial_table: TPartialTableSchema) -> None: # add the whole new table to SchemaTables self._schema_tables[table_name] = partial_table else: - # check if table properties can be merged - if table.get("parent") != partial_table.get("parent"): - raise TablePropertiesConflictException(table_name, "parent", table.get("parent"), partial_table.get("parent")) - # check if partial table has write disposition set - partial_w_d = partial_table.get("write_disposition") - if partial_w_d: - # get write disposition recursively for existing table - existing_w_d = self.get_write_disposition(table_name) - if existing_w_d != partial_w_d: - raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) - # add several columns to existing table - table_columns = table["columns"] - for column in partial_table["columns"].values(): - column_name = column["name"] - if column_name in table_columns: - # we do not support changing existing columns - if not utils.compare_columns(table_columns[column_name], column): - # attempt to update to incompatible columns - raise CannotCoerceColumnException(table_name, column_name, column["data_type"], table_columns[column_name]["data_type"], None) - else: - table_columns[column_name] = column + # merge tables performing additional checks + utils.merge_tables(table, partial_table) def bump_version(self) -> Tuple[int, str]: """Computes schema hash in order to check if schema content was modified. 
In such case the schema ``stored_version`` and ``stored_version_hash`` are updated. @@ -223,7 +209,7 @@ def bump_version(self) -> Tuple[int, str]: self._stored_version, self._stored_version_hash = version return version - def filter_row_with_hint(self, table_name: str, hint_type: THintType, row: StrAny) -> StrAny: + def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: StrAny) -> StrAny: rv_row: DictStrAny = {} column_prop: TColumnProp = utils.hint_to_column_prop(hint_type) try: @@ -241,7 +227,7 @@ def filter_row_with_hint(self, table_name: str, hint_type: THintType, row: StrAn # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns return rv_row - def merge_hints(self, new_hints: Mapping[THintType, Sequence[TSimpleRegex]]) -> None: + def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: # validate regexes validate_dict(TSchemaSettings, {"default_hints": new_hints}, ".", validator_f=utils.simple_regex_validator) # prepare hints to be added @@ -256,7 +242,21 @@ def merge_hints(self, new_hints: Mapping[THintType, Sequence[TSimpleRegex]]) -> default_hints[h] = l # type: ignore self._compile_regexes() - def get_schema_update_for(self, table_name: str, t: TTableSchemaColumns) -> List[TColumnSchema]: + def normalize_table_identifiers(self, table: TTableSchema) -> TTableSchema: + # normalize all identifiers in table according to name normalizer of the schema + table["name"] = self.normalize_table_name(table["name"]) + parent = table.get("parent") + if parent: + table["parent"] = self.normalize_table_name(parent) + columns = table.get("columns") + if columns: + for c in columns.values(): + c["name"] = self.normalize_column_name(c["name"]) + # re-index columns as the name changed + table["columns"] = {c["name"]:c for c in columns.values()} + return table + + def get_new_columns(self, table_name: str, t: TTableSchemaColumns) -> List[TColumnSchema]: # gets new columns to be added to "t" to bring up to date with stored schema diff_c: List[TColumnSchema] = [] s_t = self.get_table_columns(table_name) @@ -356,7 +356,7 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: # infer type or get it from existing table col_type = existing_column.get("data_type") if existing_column else self._infer_column_type(v, col_name) - # get real python type + # get data type of value py_type = utils.py_type_to_sc_type(type(v)) # and coerce type if inference changed the python type try: @@ -373,7 +373,8 @@ def _coerce_non_null_value(self, table_columns: TTableSchemaColumns, table_name: return self._coerce_non_null_value(table_columns, table_name, variant_col_name, v, final=True) # if coerced value is variant, then extract variant value - if isinstance(coerced_v, SupportsVariant): + # note: checking runtime protocols with isinstance(coerced_v, SupportsVariant): is extremely slow so we check if callable as every variant is callable + if callable(coerced_v): # and isinstance(coerced_v, SupportsVariant): coerced_v = coerced_v() if isinstance(coerced_v, tuple): # variant recovered so call recursively with variant column name and variant value @@ -406,7 +407,7 @@ def _infer_column_type(self, v: Any, col_name: str) -> TDataType: pass return mapped_type - def _infer_hint(self, hint_type: THintType, _: Any, col_name: str) -> bool: + def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: if hint_type in self._compiled_hints: return any(h.search(col_name) for h in 
self._compiled_hints[hint_type]) else: @@ -424,6 +425,7 @@ def _add_standard_hints(self) -> None: def _configure_normalizers(self) -> None: if not self._normalizers_config: # create default normalizer config + # TODO: pass default normalizers as context or as config with defaults self._normalizers_config = utils.default_normalizers() # import desired modules naming_module = import_module(self._normalizers_config["names"]) @@ -431,7 +433,7 @@ def _configure_normalizers(self) -> None: # name normalization functions self.normalize_table_name = naming_module.normalize_table_name self.normalize_column_name = naming_module.normalize_column_name - self.normalize_schema_name = utils.normalize_schema_name + self.normalize_schema_name = naming_module.normalize_schema_name self.normalize_make_dataset_name = naming_module.normalize_make_dataset_name self.normalize_make_path = naming_module.normalize_make_path self.normalize_break_path = naming_module.normalize_break_path @@ -439,6 +441,10 @@ def _configure_normalizers(self) -> None: self.normalize_data_item = json_module.normalize_data_item json_module.extend_schema(self) + def _verify_schema_name(self, name: str) -> None: + if name != self.normalize_schema_name(name): + raise InvalidSchemaName(name, self.normalize_schema_name(name)) + def _compile_regexes(self) -> None: if self._settings: for pattern, dt in self._settings.get("preferred_types", {}).items(): @@ -446,14 +452,14 @@ def _compile_regexes(self) -> None: self._compiled_preferred_types.append((utils.compile_simple_regex(pattern), dt)) for hint_name, hint_list in self._settings.get("default_hints", {}).items(): # compile hints which are column matching regexes - self._compiled_hints[hint_name] = list(map(lambda hint: utils.compile_simple_regex(hint), hint_list)) + self._compiled_hints[hint_name] = list(map(utils.compile_simple_regex, hint_list)) if self._schema_tables: for table in self._schema_tables.values(): if "filters" in table: if "excludes" in table["filters"]: - self._compiled_excludes[table["name"]] = list(map(lambda exclude: utils.compile_simple_regex(exclude), table["filters"]["excludes"])) + self._compiled_excludes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["excludes"])) if "includes" in table["filters"]: - self._compiled_includes[table["name"]] = list(map(lambda exclude: utils.compile_simple_regex(exclude), table["filters"]["includes"])) + self._compiled_includes[table["name"]] = list(map(utils.compile_simple_regex, table["filters"]["includes"])) def __repr__(self) -> str: return f"Schema {self.name} at {id(self)}" diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index c587b864b6..f7078e6e7a 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -4,7 +4,7 @@ TDataType = Literal["text", "double", "bool", "timestamp", "bigint", "binary", "complex", "decimal", "wei"] -THintType = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] +TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] TColumnProp = Literal["name", "data_type", "nullable", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"] TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTypeDetections = Literal["timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double"] @@ -12,7 +12,7 @@ DATA_TYPES: Set[TDataType] = set(get_args(TDataType)) COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) 
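# [editor's sketch, not part of the patch] The THintType -> TColumnHint rename keeps the
# Literal-based pattern used throughout typing.py: allowed values live in the Literal
# type and runtime sets are built from them (COLUMN_HINTS in the hunk below enumerates
# everything except "not_null"). A small self-contained sketch of that pattern and of
# the hint -> column property mapping that schema/utils.py applies later in this diff.
from typing import Literal, Set, get_args

TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique"]
ALL_HINTS: Set[str] = set(get_args(TColumnHint))
COLUMN_HINTS: Set[str] = ALL_HINTS - {"not_null"}  # same members as the explicit set below

def hint_to_column_prop(h: str) -> str:
    # "not_null" is stored on columns as the "nullable" property, all other hints map 1:1
    return "nullable" if h == "not_null" else h

assert hint_to_column_prop("not_null") == "nullable"
assert COLUMN_HINTS == {"partition", "cluster", "primary_key", "foreign_key", "sort", "unique"}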
-COLUMN_HINTS: Set[THintType] = set(["partition", "cluster", "primary_key", "foreign_key", "sort", "unique"]) +COLUMN_HINTS: Set[TColumnHint] = set(["partition", "cluster", "primary_key", "foreign_key", "sort", "unique"]) WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition)) @@ -74,7 +74,7 @@ class TNormalizersConfig(TypedDict, total=True): class TSchemaSettings(TypedDict, total=False): schema_sealed: Optional[bool] - default_hints: Optional[Dict[THintType, List[TSimpleRegex]]] + default_hints: Optional[Dict[TColumnHint, List[TSimpleRegex]]] preferred_types: Optional[Dict[TSimpleRegex, TDataType]] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index ae2b0b0e07..4a7d2f4f2f 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -5,6 +5,7 @@ import datetime # noqa: I251 import contextlib from copy import deepcopy +from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from typing import Dict, List, Sequence, Tuple, Type, Any, cast from dlt.common import pendulum, json, Decimal, Wei @@ -18,8 +19,8 @@ from dlt.common.utils import str2bool from dlt.common.validation import TCustomValidator, validate_dict from dlt.common.schema import detections -from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, THintType, TTypeDetectionFunc, TTypeDetections, TWriteDisposition -from dlt.common.schema.exceptions import ParentTableNotFoundException, SchemaEngineNoUpgradePathException +from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX, TColumnName, TNormalizersConfig, TPartialTableSchema, TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TDataType, TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition +from dlt.common.schema.exceptions import CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException, TablePropertiesConflictException RE_LEADING_DIGITS = re.compile(r"^\d+") @@ -28,19 +29,6 @@ DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" -# fix a name so it is acceptable as schema name -def normalize_schema_name(name: str) -> str: - # empty and None schema names are not allowed - if not name: - raise ValueError(name) - - # prefix the name starting with digits - if RE_LEADING_DIGITS.match(name): - name = "s" + name - # leave only alphanumeric - return RE_NON_ALPHANUMERIC.sub("", name).lower() - - def apply_defaults(stored_schema: TStoredSchema) -> None: for table_name, table in stored_schema["tables"].items(): # overwrite name @@ -95,13 +83,13 @@ def generate_version_hash(stored_schema: TStoredSchema) -> str: content = json.dumps(schema_copy, sort_keys=True) h = hashlib.sha3_256(content.encode("utf-8")) # additionally check column order - table_names = sorted(schema_copy.get("tables", {}).keys()) + table_names = sorted((schema_copy.get("tables") or {}).keys()) if table_names: for tn in table_names: t = schema_copy["tables"][tn] h.update(tn.encode("utf-8")) # add column names to hash in order - for cn in t.get("columns", {}).keys(): + for cn in (t.get("columns") or {}).keys(): h.update(cn.encode("utf-8")) return base64.b64encode(h.digest()).decode('ascii') @@ -193,7 +181,7 @@ def upgrade_engine_version(schema_dict: DictStrAny, from_engine: int, to_engine: } } # move settings, convert strings to simple regexes - d_h: Dict[THintType, 
List[TSimpleRegex]] = schema_dict.pop("hints", {}) + d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) for h_k, h_l in d_h.items(): d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) @@ -281,26 +269,42 @@ def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: def py_type_to_sc_type(t: Type[Any]) -> TDataType: - if issubclass(t, float): + # start with most popular types + if t is str: + return "text" + if t is float: return "double" # bool is subclass of int so must go first - elif t is bool: + if t is bool: return "bool" - elif issubclass(t, int): + if t is int: return "bigint" - elif issubclass(t, bytes): - return "binary" - elif issubclass(t, (dict, list)): + if issubclass(t, (dict, list)): return "complex" + + # those are special types that will not be present in json loaded dict # wei is subclass of decimal and must be checked first - elif issubclass(t, Wei): + if issubclass(t, Wei): return "wei" - elif issubclass(t, Decimal): + if issubclass(t, Decimal): return "decimal" - elif issubclass(t, datetime.datetime): + # TODO: implement new "date" type, currently assign "datetime" + if issubclass(t, (datetime.datetime, datetime.date)): return "timestamp" - else: + + # check again for subclassed basic types + if issubclass(t, str): return "text" + if issubclass(t, float): + return "double" + if issubclass(t, int): + return "bigint" + if issubclass(t, bytes): + return "binary" + if issubclass(t, (C_Mapping, C_Sequence)): + return "complex" + + raise TypeError(t) def coerce_type(to_type: TDataType, from_type: TDataType, value: Any) -> Any: @@ -432,11 +436,62 @@ def coerce_type(to_type: TDataType, from_type: TDataType, value: Any) -> Any: raise ValueError(value) -def compare_columns(a: TColumnSchema, b: TColumnSchema) -> bool: +def diff_tables(tab_a: TTableSchema, tab_b: TTableSchema, ignore_table_name: bool = True) -> TPartialTableSchema: + table_name = tab_a["name"] + if not ignore_table_name and table_name != tab_b["name"]: + raise TablePropertiesConflictException(table_name, "name", table_name, tab_b["name"]) + + # check if table properties can be merged + if tab_a.get("parent") != tab_b.get("parent"): + raise TablePropertiesConflictException(table_name, "parent", tab_a.get("parent"), tab_b.get("parent")) + # check if partial table has write disposition set + partial_w_d = tab_b.get("write_disposition") + if partial_w_d: + existing_w_d = tab_a.get("write_disposition") + if existing_w_d != partial_w_d: + raise TablePropertiesConflictException(table_name, "write_disposition", existing_w_d, partial_w_d) + + # get new columns, changes in the column data type or other properties are not allowed + table_columns = tab_a["columns"] + new_columns: List[TColumnSchema] = [] + for column in tab_b["columns"].values(): + column_name = column["name"] + if column_name in table_columns: + # we do not support changing existing columns + if not compare_column(table_columns[column_name], column): + # attempt to update to incompatible columns + raise CannotCoerceColumnException(table_name, column_name, column["data_type"], table_columns[column_name]["data_type"], None) + else: + new_columns.append(column) + + # TODO: compare filters, description etc. + + # return partial table containing only name and properties that differ (column, filters etc.) 
+ return new_table(table_name, columns=new_columns) + + +def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool: + try: + diff_table = diff_tables(tab_a, tab_b, ignore_table_name=False) + # columns cannot differ + return len(diff_table["columns"]) == 0 + except SchemaException: + return False + + +def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TTableSchema: + # merges "partial_table" into "table", preserving the "table" name + diff_table = diff_tables(table, partial_table, ignore_table_name=True) + # add new columns when all checks passed + table["columns"].update(diff_table["columns"]) + return table + + +def compare_column(a: TColumnSchema, b: TColumnSchema) -> bool: return a["data_type"] == b["data_type"] and a["nullable"] == b["nullable"] -def hint_to_column_prop(h: THintType) -> TColumnProp: +def hint_to_column_prop(h: TColumnHint) -> TColumnProp: if h == "not_null": return "nullable" return h @@ -490,20 +545,40 @@ def load_table() -> TTableSchema: return table -def new_table(table_name: str, parent_name: str = None, write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None) -> TTableSchema: +def new_table( + table_name: str, + parent_table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + validate_schema: bool = False +) -> TTableSchema: + table: TTableSchema = { "name": table_name, "columns": {} if columns is None else {c["name"]: add_missing_hints(c) for c in columns} } - if parent_name: - table["parent"] = parent_name + if parent_table_name: + table["parent"] = parent_table_name assert write_disposition is None else: # set write disposition only for root tables table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION + if validate_schema: + validate_dict(TTableSchema, table, f"new_table/{table_name}") return table +def new_column(column_name: str, data_type: TDataType, nullable: bool = True, validate_schema: bool = False) -> TColumnSchema: + column = add_missing_hints({ + "name": column_name, + "data_type": data_type, + "nullable": nullable + }) + if validate_schema: + validate_dict(TColumnSchema, column, f"new_column/{column_name}") + return column + + def default_normalizers() -> TNormalizersConfig: return { "detections": ["timestamp", "iso_timestamp"], @@ -514,5 +589,5 @@ def default_normalizers() -> TNormalizersConfig: } -def standard_hints() -> Dict[THintType, List[TSimpleRegex]]: +def standard_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: return None diff --git a/dlt/common/signals.py b/dlt/common/signals.py index 2202bcb502..7a06d66d0f 100644 --- a/dlt/common/signals.py +++ b/dlt/common/signals.py @@ -12,16 +12,16 @@ exit_event = Event() -def signal_receiver(signal: int, frame: Any) -> None: +def signal_receiver(sig: int, frame: Any) -> None: global _received_signal - logger.info(f"Signal {signal} received") + logger.info(f"Signal {sig} received") if _received_signal > 0: logger.info(f"Another signal received after {_received_signal}") return - _received_signal = signal + _received_signal = sig # awake all threads sleeping on event exit_event.set() diff --git a/dlt/common/sources.py b/dlt/common/sources.py deleted file mode 100644 index 0ae2bc4c48..0000000000 --- a/dlt/common/sources.py +++ /dev/null @@ -1,82 +0,0 @@ -from collections import abc -from functools import wraps -from typing import Any, Callable, Optional, Sequence, TypeVar, Union, TypedDict -try: - from typing_extensions import ParamSpec -except ImportError: - 
ParamSpec = lambda x: [x] # type: ignore - -from dlt.common import logger -from dlt.common.time import sleep -from dlt.common.typing import StrAny, TDataItem - - -# possible types of items yielded by the source -# 1. document (mapping from str to any type) -# 2. Iterable (ie list) on the mapping above for returning many documents with single yield -TItem = Union[TDataItem, Sequence[TDataItem]] -TBoundItem = TypeVar("TBoundItem", bound=TItem) -TDeferred = Callable[[], TBoundItem] - -_TFunParams = ParamSpec("_TFunParams") - -# name of dlt metadata as part of the item -DLT_METADATA_FIELD = "_dlt_meta" - - -class TEventDLTMeta(TypedDict, total=False): - table_name: str # a root table in which store the event - - -def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: - if isinstance(item, abc.Sequence): - for i in item: - i.setdefault(DLT_METADATA_FIELD, {})[name] = value - elif isinstance(item, dict): - item.setdefault(DLT_METADATA_FIELD, {})[name] = value - - return item - - -def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: - # normalize table name before adding - return append_dlt_meta(item, "table_name", table_name) - - -def get_table_name(item: StrAny) -> Optional[str]: - if DLT_METADATA_FIELD in item: - meta: TEventDLTMeta = item[DLT_METADATA_FIELD] - return meta.get("table_name", None) - return None - - -def with_retry(max_retries: int = 3, retry_sleep: float = 1.0) -> Callable[[Callable[_TFunParams, TBoundItem]], Callable[_TFunParams, TBoundItem]]: - - def decorator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TBoundItem]: - - def _wrap(*args: Any, **kwargs: Any) -> TBoundItem: - attempts = 0 - while True: - try: - return f(*args, **kwargs) - except Exception as exc: - if attempts == max_retries: - raise - attempts += 1 - logger.warning(f"Exception {exc} in iterator, retrying {attempts} / {max_retries}") - sleep(retry_sleep) - - return _wrap - - return decorator - - -def defer_iterator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TDeferred[TBoundItem]]: - - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> TDeferred[TBoundItem]: - def _curry() -> TBoundItem: - return f(*args, **kwargs) - return _curry - - return _wrap diff --git a/dlt/common/storages/__init__.py b/dlt/common/storages/__init__.py index e82ac29eb2..68d8c4aea4 100644 --- a/dlt/common/storages/__init__.py +++ b/dlt/common/storages/__init__.py @@ -1,2 +1,7 @@ +from .file_storage import FileStorage # noqa: F401 from .schema_storage import SchemaStorage # noqa: F401 from .live_schema_storage import LiveSchemaStorage # noqa: F401 +from .normalize_storage import NormalizeStorage # noqa: F401 +from .versioned_storage import VersionedStorage # noqa: F401 +from .load_storage import LoadStorage # noqa: F401 +from .data_item_storage import DataItemStorage # noqa: F401 diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py new file mode 100644 index 0000000000..27a4a688b1 --- /dev/null +++ b/dlt/common/storages/data_item_storage.py @@ -0,0 +1,38 @@ +from typing import Dict, Any +from abc import ABC, abstractmethod + +from dlt.common import logger +from dlt.common.schema import TTableSchemaColumns +from dlt.common.typing import TDataItems +from dlt.common.data_writers import TLoaderFileFormat, BufferedDataWriter + + +class DataItemStorage(ABC): + def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None: + self.loader_file_format = load_file_type + self.buffered_writers: Dict[str, 
BufferedDataWriter] = {} + super().__init__(*args) + + def write_data_item(self, load_id: str, schema_name: str, table_name: str, item: TDataItems, columns: TTableSchemaColumns) -> None: + # unique writer id + writer_id = f"{load_id}.{schema_name}.{table_name}" + writer = self.buffered_writers.get(writer_id, None) + if not writer: + # assign a jsonl writer for each table + path = self._get_data_item_path_template(load_id, schema_name, table_name) + writer = BufferedDataWriter(self.loader_file_format, path) + self.buffered_writers[writer_id] = writer + # write item(s) + writer.write_data_item(item, columns) + + def close_writers(self, extract_id: str) -> None: + # flush and close all files + for name, writer in self.buffered_writers.items(): + if name.startswith(extract_id): + logger.debug(f"Closing writer for {name} with file {writer._file} and actual name {writer._file_name}") + writer.close_writer() + + @abstractmethod + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + # note: use %s for file id to create required template format + pass diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index 162fc76a93..4f5d2e3551 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -2,7 +2,7 @@ from typing import Iterable from dlt.common.exceptions import DltException -from dlt.common.dataset_writers import TLoaderFileFormat +from dlt.common.data_writers import TLoaderFileFormat class StorageException(DltException): @@ -32,10 +32,11 @@ class LoaderStorageException(StorageException): class JobWithUnsupportedWriterException(LoaderStorageException): - def __init__(self, load_id: str, expected_file_format: Iterable[TLoaderFileFormat], wrong_job: str) -> None: + def __init__(self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str) -> None: self.load_id = load_id - self.expected_file_format = expected_file_format + self.expected_file_formats = expected_file_formats self.wrong_job = wrong_job + super().__init__(f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of {expected_file_formats}") class SchemaStorageException(StorageException): diff --git a/dlt/common/file_storage.py b/dlt/common/storages/file_storage.py similarity index 69% rename from dlt/common/file_storage.py rename to dlt/common/storages/file_storage.py index c626d4af2c..8947515a42 100644 --- a/dlt/common/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -1,7 +1,7 @@ import os import tempfile import shutil -from pathlib import Path +import pathvalidate from typing import IO, Any, List from dlt.common.utils import encoding_for_mode @@ -18,10 +18,6 @@ def __init__(self, if makedirs: os.makedirs(storage_path, exist_ok=True) - @classmethod - def from_file(cls, file_path: str, file_type: str = "t",) -> "FileStorage": - return cls(os.path.dirname(file_path), file_type) - def save(self, relative_path: str, data: Any) -> str: return self.save_atomic(self.storage_path, relative_path, data, file_type=self.file_type) @@ -47,14 +43,14 @@ def load(self, relative_path: str) -> Any: return text_file.read() def delete(self, relative_path: str) -> None: - file_path = self._make_path(relative_path) + file_path = self.make_full_path(relative_path) if os.path.isfile(file_path): os.remove(file_path) else: raise FileNotFoundError(file_path) def delete_folder(self, relative_path: str, recursively: bool = False) -> None: - folder_path = self._make_path(relative_path) + 
folder_path = self.make_full_path(relative_path) if os.path.isdir(folder_path): if recursively: shutil.rmtree(folder_path) @@ -65,17 +61,17 @@ def delete_folder(self, relative_path: str, recursively: bool = False) -> None: def open_file(self, realtive_path: str, mode: str = "r") -> IO[Any]: mode = mode + self.file_type - return open(self._make_path(realtive_path), mode, encoding=encoding_for_mode(mode)) + return open(self.make_full_path(realtive_path), mode, encoding=encoding_for_mode(mode)) def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: mode = mode + file_type or self.file_type return tempfile.NamedTemporaryFile(dir=self.storage_path, mode=mode, delete=delete, encoding=encoding_for_mode(mode)) def has_file(self, relative_path: str) -> bool: - return os.path.isfile(self._make_path(relative_path)) + return os.path.isfile(self.make_full_path(relative_path)) def has_folder(self, relative_path: str) -> bool: - return os.path.isdir(self._make_path(relative_path)) + return os.path.isdir(self.make_full_path(relative_path)) def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[str]: """List all files in ``relative_path`` folder @@ -87,7 +83,7 @@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st Returns: List[str]: A list of file names with optional path as per ``to_root`` parameter """ - scan_path = self._make_path(relative_path) + scan_path = self.make_full_path(relative_path) if to_root: # list files in relative path, returning paths relative to storage root return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_file()] @@ -97,7 +93,7 @@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str]: # list content of relative path, returning paths relative to storage root - scan_path = self._make_path(relative_path) + scan_path = self.make_full_path(relative_path) if to_root: # list folders in relative path, returning paths relative to storage root return [os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_dir()] @@ -106,25 +102,19 @@ def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str return [e.name for e in os.scandir(scan_path) if e.is_dir()] def create_folder(self, relative_path: str, exists_ok: bool = False) -> None: - os.makedirs(self._make_path(relative_path), exist_ok=exists_ok) - - def copy_cross_storage_atomically(self, dest_volume_root: str, dest_relative_path: str, source_path: str, dest_name: str) -> None: - external_tmp_file = tempfile.mktemp(dir=dest_volume_root) - # first copy to temp file - shutil.copy(self._make_path(source_path), external_tmp_file) - # then rename to dest name - external_dest = os.path.join(dest_volume_root, dest_relative_path, dest_name) - try: - os.rename(external_tmp_file, external_dest) - except Exception: - if os.path.isfile(external_tmp_file): - os.remove(external_tmp_file) - raise + os.makedirs(self.make_full_path(relative_path), exist_ok=exists_ok) + + def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: + # note: some interesting stuff on links https://lightrun.com/answers/conan-io-conan-research-investigate-symlinks-and-hard-links + os.link( + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path) + ) def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: os.rename( - 
self._make_path(from_relative_path), - self._make_path(to_relative_path) + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path) ) def in_storage(self, path: str) -> bool: @@ -138,11 +128,26 @@ def to_relative_path(self, path: str) -> str: raise ValueError(path) return os.path.relpath(path, start=self.storage_path) - def get_file_stem(self, path: str) -> str: - return Path(os.path.basename(path)).stem + def make_full_path(self, path: str) -> str: + # try to make a relative path if paths are absolute or overlapping + try: + path = self.to_relative_path(path) + except ValueError: + # if path is absolute and cannot be made relative to the storage then cannot be made full path with storage root + if os.path.isabs(path): + raise ValueError(path) - def get_file_name(self, path: str) -> str: - return Path(path).name + # then assume that it is a path relative to storage root + return os.path.join(self.storage_path, path) - def _make_path(self, relative_path: str) -> str: - return os.path.join(self.storage_path, relative_path) + @staticmethod + def get_file_name_from_file_path(file_path: str) -> str: + return os.path.basename(file_path) + + @staticmethod + def validate_file_name_component(name: str) -> None: + # Universal platform bans several characters allowed in POSIX ie. | < \ or "COM1" :) + pathvalidate.validate_filename(name, platform="Universal") + # component cannot contain "." + if "." in name: + raise pathvalidate.error.InvalidCharError(reason="Component name cannot contain . (dots)") diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index 3c1a131f09..af7ce33af9 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,14 +1,24 @@ -from typing import Dict, Type +from typing import Any, Dict, overload -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.typing import ConfigValue from dlt.common.schema.schema import Schema from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.configuration.specs import SchemaVolumeConfiguration class LiveSchemaStorage(SchemaStorage): - def __init__(self, C: Type[SchemaVolumeConfiguration], makedirs: bool = False) -> None: + + @overload + def __init__(self, config: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + ... + + @overload + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + ... 
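# [editor's sketch, not part of the patch] LiveSchemaStorage here, and LoadStorage and
# NormalizeStorage further down, all follow the same construction pattern: two @overload
# stubs describe the call shapes for the type checker (explicit config vs. injected
# ConfigValue default) while a single real __init__ does the work. A minimal sketch of
# that overload pattern without dlt's with_config/ConfigValue machinery; ExampleConfig
# and ExampleStorage are stand-ins.
from typing import Optional, overload

class ExampleConfig:
    volume_path: str = "_storage"

class ExampleStorage:
    @overload
    def __init__(self, config: ExampleConfig) -> None: ...
    @overload
    def __init__(self, config: None = None) -> None: ...

    def __init__(self, config: Optional[ExampleConfig] = None) -> None:
        # dlt would resolve a missing value through configuration injection;
        # this sketch simply falls back to a default instance
        self.config = config or ExampleConfig()

print(ExampleStorage().config.volume_path)  # -> _storage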
+ + def __init__(self, config: SchemaVolumeConfiguration = None, makedirs: bool = False) -> None: self.live_schemas: Dict[str, Schema] = {} - super().__init__(C, makedirs) + super().__init__(config, makedirs) def __getitem__(self, name: str) -> Schema: # disconnect live schema @@ -26,19 +36,26 @@ def load_schema(self, name: str) -> Schema: def save_schema(self, schema: Schema) -> str: rv = super().save_schema(schema) - # update the live schema with schema being saved but to not create live instance if not already present + # update the live schema with schema being saved but do not create live instance if not already present self._update_live_schema(schema, False) return rv + def initialize_import_if_new(self, schema: Schema) -> None: + if self.config.import_schema_path and schema.name not in self: + try: + self._load_import_schema(schema.name) + except FileNotFoundError: + # save import schema only if it not exist + self._export_schema(schema, self.config.import_schema_path) + def commit_live_schema(self, name: str) -> Schema: # if live schema exists and is modified then it must be used as an import schema live_schema = self.live_schemas.get(name) if live_schema and live_schema.stored_version_hash != live_schema.version_hash: - print("bumping and saving") live_schema.bump_version() - if self.C.IMPORT_SCHEMA_PATH: + if self.config.import_schema_path: # overwrite import schemas if specified - self._export_schema(live_schema, self.C.IMPORT_SCHEMA_PATH) + self._export_schema(live_schema, self.config.import_schema_path) else: # write directly to schema storage if no import schema folder configured self._save_schema(live_schema) diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index ee2b74b481..74246f60b5 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,17 +1,18 @@ import os from os.path import join from pathlib import Path -from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, Tuple, Type, get_args +from typing import Iterable, NamedTuple, Literal, Optional, Sequence, Set, get_args, overload from dlt.common import json, pendulum -from dlt.common.file_storage import FileStorage -from dlt.common.dataset_writers import TLoaderFileFormat, write_jsonl, write_insert_values -from dlt.common.configuration import LoadVolumeConfiguration +from dlt.common.configuration.inject import with_config +from dlt.common.typing import ConfigValue, DictStrAny, StrAny +from dlt.common.storages.file_storage import FileStorage +from dlt.common.data_writers import TLoaderFileFormat, DataWriter +from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaUpdate, TTableSchemaColumns from dlt.common.storages.versioned_storage import VersionedStorage -from dlt.common.typing import DictStrAny, StrAny - +from dlt.common.storages.data_item_storage import DataItemStorage from dlt.common.storages.exceptions import JobWithUnsupportedWriterException @@ -24,7 +25,7 @@ class TParsedJobFileName(NamedTuple): file_format: TLoaderFileFormat -class LoadStorage(VersionedStorage): +class LoadStorage(DataItemStorage, VersionedStorage): STORAGE_VERSION = "1.0.0" NORMALIZED_FOLDER = "normalized" # folder within the volume where load packages are stored @@ -41,21 +42,33 @@ class LoadStorage(VersionedStorage): ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) + @overload + def __init__(self, 
is_owner: bool, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], config: LoadVolumeConfiguration) -> None: + ... + + @overload + def __init__(self, is_owner: bool, preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], config: LoadVolumeConfiguration = ConfigValue) -> None: + ... + + @with_config(spec=LoadVolumeConfiguration, namespaces=("load",)) def __init__( self, is_owner: bool, - C: Type[LoadVolumeConfiguration], preferred_file_format: TLoaderFileFormat, - supported_file_formats: Iterable[TLoaderFileFormat] + supported_file_formats: Iterable[TLoaderFileFormat], + config: LoadVolumeConfiguration = ConfigValue ) -> None: if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): raise TerminalValueError(supported_file_formats) if preferred_file_format not in supported_file_formats: raise TerminalValueError(preferred_file_format) - self.preferred_file_format = preferred_file_format self.supported_file_formats = supported_file_formats - self.delete_completed_jobs = C.DELETE_COMPLETED_JOBS - super().__init__(LoadStorage.STORAGE_VERSION, is_owner, FileStorage(C.LOAD_VOLUME_PATH, "t", makedirs=is_owner)) + self.config = config + super().__init__( + preferred_file_format, + LoadStorage.STORAGE_VERSION, + is_owner, FileStorage(config.load_volume_path, "t", makedirs=is_owner) + ) if is_owner: self.initialize_storage() @@ -74,13 +87,15 @@ def create_temp_load_package(self, load_id: str) -> None: self.storage.create_folder(join(load_id, LoadStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(join(load_id, LoadStorage.STARTED_JOBS_FOLDER)) + def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: + file_name = self.build_job_file_name(table_name, "%s", with_extension=False) + return self.storage.make_full_path(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name)) + def write_temp_job_file(self, load_id: str, table_name: str, table: TTableSchemaColumns, file_id: str, rows: Sequence[StrAny]) -> str: - file_name = self.build_job_file_name(table_name, file_id) - with self.storage.open_file(join(load_id, LoadStorage.NEW_JOBS_FOLDER, file_name), mode="w") as f: - if self.preferred_file_format == "jsonl": - write_jsonl(f, rows) - elif self.preferred_file_format == "insert_values": - write_insert_values(f, rows, table.keys()) + file_name = self._get_data_item_path_template(load_id, None, table_name) % file_id + "." + self.loader_file_format + with self.storage.open_file(file_name, mode="w") as f: + writer = DataWriter.from_file_format(self.loader_file_format, f) + writer.write_all(table, rows) return Path(file_name).name def load_package_schema(self, load_id: str) -> Schema: @@ -176,7 +191,7 @@ def complete_load_package(self, load_id: str) -> None: load_path = self.get_package_path(load_id) has_failed_jobs = len(self.list_failed_jobs(load_id)) > 0 # delete load that does not contain failed jobs - if self.delete_completed_jobs and not has_failed_jobs: + if self.config.delete_completed_jobs and not has_failed_jobs: self.storage.delete_folder(load_path, recursively=True) else: completed_path = self.get_completed_package_path(load_id) @@ -188,13 +203,6 @@ def get_package_path(self, load_id: str) -> str: def get_completed_package_path(self, load_id: str) -> str: return join(LoadStorage.LOADED_FOLDER, load_id) - def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0) -> str: - if "." in table_name: - raise ValueError(table_name) - if "." 
in file_id: - raise ValueError(file_id) - return f"{table_name}.{file_id}.{int(retry_count)}.{self.preferred_file_format}" - def job_elapsed_time_seconds(self, file_path: str) -> float: return pendulum.now().timestamp() - os.path.getmtime(file_path) # type: ignore @@ -211,7 +219,7 @@ def _move_job(self, load_id: str, source_folder: TWorkingFolder, dest_folder: TW load_path = self.get_package_path(load_id) dest_path = join(load_path, dest_folder, new_file_name or file_name) self.storage.atomic_rename(join(load_path, source_folder, file_name), dest_path) - return self.storage._make_path(dest_path) + return self.storage.make_full_path(dest_path) def _get_job_folder_path(self, load_id: str, folder: TWorkingFolder) -> str: return join(self.get_package_path(load_id), folder) @@ -219,6 +227,15 @@ def _get_job_folder_path(self, load_id: str, folder: TWorkingFolder) -> str: def _get_job_file_path(self, load_id: str, folder: TWorkingFolder, file_name: str) -> str: return join(self._get_job_folder_path(load_id, folder), file_name) + def build_job_file_name(self, table_name: str, file_id: str, retry_count: int = 0, validate_components: bool = True, with_extension: bool = True) -> str: + if validate_components: + FileStorage.validate_file_name_component(table_name) + FileStorage.validate_file_name_component(file_id) + fn = f"{table_name}.{file_id}.{int(retry_count)}" + if with_extension: + return fn + f".{self.loader_file_format}" + return fn + @staticmethod def parse_job_file_name(file_name: str) -> TParsedJobFileName: p = Path(file_name) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index fd52ae72e5..b05c3adf2a 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,22 +1,37 @@ -from typing import List, Sequence, Tuple, Type +from typing import ClassVar, Sequence, NamedTuple, overload from itertools import groupby from pathlib import Path -from dlt.common.utils import chunks -from dlt.common.file_storage import FileStorage -from dlt.common.configuration import NormalizeVolumeConfiguration +from dlt.common.storages.file_storage import FileStorage +from dlt.common.configuration import with_config +from dlt.common.configuration.specs import NormalizeVolumeConfiguration from dlt.common.storages.versioned_storage import VersionedStorage +from dlt.common.typing import ConfigValue + + +class TParsedNormalizeFileName(NamedTuple): + schema_name: str + table_name: str + file_id: str class NormalizeStorage(VersionedStorage): - STORAGE_VERSION = "1.0.0" - EXTRACTED_FOLDER: str = "extracted" # folder within the volume where extracted files to be normalized are stored - EXTRACTED_FILE_EXTENSION = ".extracted.json" - EXTRACTED_FILE_EXTENSION_LEN = len(EXTRACTED_FILE_EXTENSION) + STORAGE_VERSION: ClassVar[str] = "1.0.0" + EXTRACTED_FOLDER: ClassVar[str] = "extracted" # folder within the volume where extracted files to be normalized are stored + + @overload + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration) -> None: + ... - def __init__(self, is_owner: bool, C: Type[NormalizeVolumeConfiguration]) -> None: - super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(C.NORMALIZE_VOLUME_PATH, "t", makedirs=is_owner)) + @overload + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: + ... 
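# --- Editor's note: an illustrative usage sketch, not part of the patch. The @overload stubs
# above only describe the typing of the new storage constructors: the config spec may be passed
# explicitly, or omitted and resolved by @with_config under the given namespace. The keyword
# construction of the spec and the provider-based resolution shown below are assumptions.

from dlt.common.configuration.specs import NormalizeVolumeConfiguration
from dlt.common.storages.normalize_storage import NormalizeStorage

# explicit config - nothing is injected
explicit_storage = NormalizeStorage(True, NormalizeVolumeConfiguration(normalize_volume_path="_storage/normalize"))

# omitted config - @with_config(spec=NormalizeVolumeConfiguration, namespaces=("normalize",))
# resolves the spec from the available configuration providers instead
injected_storage = NormalizeStorage(True)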
+ + @with_config(spec=NormalizeVolumeConfiguration, namespaces=("normalize",)) + def __init__(self, is_owner: bool, config: NormalizeVolumeConfiguration = ConfigValue) -> None: + super().__init__(NormalizeStorage.STORAGE_VERSION, is_owner, FileStorage(config.normalize_volume_path, "t", makedirs=is_owner)) + self.config = config if is_owner: self.initialize_storage() @@ -27,49 +42,24 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: return sorted(self.storage.list_folder_files(NormalizeStorage.EXTRACTED_FOLDER)) def get_grouped_iterator(self, files: Sequence[str]) -> "groupby[str, str]": - return groupby(files, lambda f: NormalizeStorage.get_schema_name(f)) - - @staticmethod - def chunk_by_events(files: Sequence[str], max_events: int, processing_cores: int) -> List[Sequence[str]]: - # should distribute ~ N events evenly among m cores with fallback for small amounts of events - - def count_events(file_name : str) -> int: - # return event count from file name - return NormalizeStorage.get_events_count(file_name) - - counts = list(map(count_events, files)) - # make a list of files containing ~max_events - events_count = 0 - m = 0 - while events_count < max_events and m < len(files): - events_count += counts[m] - m += 1 - processing_chunks = round(m / processing_cores) - if processing_chunks == 0: - # return one small chunk - return [files] - else: - # should return ~ amount of chunks to fill all the cores - return list(chunks(files[:m], processing_chunks)) - - @staticmethod - def get_events_count(file_name: str) -> int: - return NormalizeStorage._parse_extracted_file_name(file_name)[0] + return groupby(files, NormalizeStorage.get_schema_name) @staticmethod def get_schema_name(file_name: str) -> str: - return NormalizeStorage._parse_extracted_file_name(file_name)[2] + return NormalizeStorage.parse_normalize_file_name(file_name).schema_name @staticmethod - def build_extracted_file_name(schema_name: str, stem: str, event_count: int, load_id: str) -> str: + def build_extracted_file_stem(schema_name: str, table_name: str, file_id: str) -> str: # builds file name with the extracted data to be passed to normalize - return f"{schema_name}_{stem}_{load_id}_{event_count}{NormalizeStorage.EXTRACTED_FILE_EXTENSION}" + return f"{schema_name}.{table_name}.{file_id}" @staticmethod - def _parse_extracted_file_name(file_name: str) -> Tuple[int, str, str]: + def parse_normalize_file_name(file_name: str) -> TParsedNormalizeFileName: # parse extracted file name and returns (events found, load id, schema_name) - if not file_name.endswith(NormalizeStorage.EXTRACTED_FILE_EXTENSION): + if not file_name.endswith("jsonl"): raise ValueError(file_name) - parts = Path(file_name[:-NormalizeStorage.EXTRACTED_FILE_EXTENSION_LEN]).stem.split("_") - return (int(parts[-1]), parts[-2], parts[0]) \ No newline at end of file + parts = Path(file_name).stem.split(".") + if len(parts) != 3: + raise ValueError(file_name) + return TParsedNormalizeFileName(*parts) diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 9fd1faa65f..0127fc2b6f 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -1,15 +1,14 @@ import os import re import yaml -from typing import Iterator, List, Type, Mapping +from typing import Iterator, List, Mapping, overload from dlt.common import json, logger -from dlt.common.configuration.schema_volume_configuration import TSchemaFileFormat -from dlt.common.file_storage import FileStorage +from dlt.common.configuration 
import with_config +from dlt.common.configuration.specs import SchemaVolumeConfiguration, TSchemaFileFormat +from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import Schema, verify_schema_hash -from dlt.common.schema.typing import TStoredSchema -from dlt.common.typing import DictStrAny -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.typing import DictStrAny, ConfigValue from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError @@ -17,11 +16,20 @@ class SchemaStorage(Mapping[str, Schema]): SCHEMA_FILE_NAME = "schema.%s" - NAMED_SCHEMA_FILE_PATTERN = f"%s_{SCHEMA_FILE_NAME}" + NAMED_SCHEMA_FILE_PATTERN = f"%s.{SCHEMA_FILE_NAME}" - def __init__(self, C: Type[SchemaVolumeConfiguration], makedirs: bool = False) -> None: - self.C = C - self.storage = FileStorage(C.SCHEMA_VOLUME_PATH, makedirs=makedirs) + @overload + def __init__(self, config: SchemaVolumeConfiguration, makedirs: bool = False) -> None: + ... + + @overload + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + ... + + @with_config(spec=SchemaVolumeConfiguration, namespaces=("schema",)) + def __init__(self, config: SchemaVolumeConfiguration = ConfigValue, makedirs: bool = False) -> None: + self.config = config + self.storage = FileStorage(config.schema_volume_path, makedirs=makedirs) def load_schema(self, name: str) -> Schema: # loads a schema from a store holding many schemas @@ -31,21 +39,21 @@ def load_schema(self, name: str) -> Schema: storage_schema = json.loads(self.storage.load(schema_file)) # prevent external modifications of schemas kept in storage if not verify_schema_hash(storage_schema, empty_hash_verifies=True): - raise InStorageSchemaModified(name, self.C.SCHEMA_VOLUME_PATH) + raise InStorageSchemaModified(name, self.config.schema_volume_path) except FileNotFoundError: # maybe we can import from external storage pass # try to import from external storage - if self.C.IMPORT_SCHEMA_PATH: + if self.config.import_schema_path: return self._maybe_import_schema(name, storage_schema) if storage_schema is None: - raise SchemaNotFoundError(name, self.C.SCHEMA_VOLUME_PATH) + raise SchemaNotFoundError(name, self.config.schema_volume_path) return Schema.from_dict(storage_schema) def save_schema(self, schema: Schema) -> str: # check if there's schema to import - if self.C.IMPORT_SCHEMA_PATH: + if self.config.import_schema_path: try: imported_schema = Schema.from_dict(self._load_import_schema(schema.name)) # link schema being saved to current imported schema so it will not overwrite this save when loaded @@ -54,8 +62,8 @@ def save_schema(self, schema: Schema) -> str: # just save the schema pass path = self._save_schema(schema) - if self.C.EXPORT_SCHEMA_PATH: - self._export_schema(schema, self.C.EXPORT_SCHEMA_PATH) + if self.config.export_schema_path: + self._export_schema(schema, self.config.export_schema_path) return path def remove_schema(self, name: str) -> None: @@ -69,7 +77,7 @@ def has_schema(self, name: str) -> bool: def list_schemas(self) -> List[str]: files = self.storage.list_folder_files(".", to_root=False) # extract names - return [re.split("_|schema", f)[0] for f in files] + return [f.split(".")[0] for f in files] def __getitem__(self, name: str) -> Schema: return self.load_schema(name) @@ -112,37 +120,37 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> except FileNotFoundError: # no schema to import -> skip silently and return the original if 
storage_schema is None: - raise SchemaNotFoundError(name, self.C.SCHEMA_VOLUME_PATH, self.C.IMPORT_SCHEMA_PATH, self.C.EXTERNAL_SCHEMA_FORMAT) + raise SchemaNotFoundError(name, self.config.schema_volume_path, self.config.import_schema_path, self.config.external_schema_format) rv_schema = Schema.from_dict(storage_schema) assert rv_schema is not None return rv_schema def _load_import_schema(self, name: str) -> DictStrAny: - import_storage = FileStorage(self.C.IMPORT_SCHEMA_PATH, makedirs=False) - schema_file = self._file_name_in_store(name, self.C.EXTERNAL_SCHEMA_FORMAT) + import_storage = FileStorage(self.config.import_schema_path, makedirs=False) + schema_file = self._file_name_in_store(name, self.config.external_schema_format) imported_schema: DictStrAny = None imported_schema_s = import_storage.load(schema_file) - if self.C.EXTERNAL_SCHEMA_FORMAT == "json": + if self.config.external_schema_format == "json": imported_schema = json.loads(imported_schema_s) - elif self.C.EXTERNAL_SCHEMA_FORMAT == "yaml": + elif self.config.external_schema_format == "yaml": imported_schema = yaml.safe_load(imported_schema_s) else: - raise ValueError(self.C.EXTERNAL_SCHEMA_FORMAT) + raise ValueError(self.config.external_schema_format) return imported_schema def _export_schema(self, schema: Schema, export_path: str) -> None: - if self.C.EXTERNAL_SCHEMA_FORMAT == "json": - exported_schema_s = schema.to_pretty_json(remove_defaults=self.C.EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS) - elif self.C.EXTERNAL_SCHEMA_FORMAT == "yaml": - exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.C.EXTERNAL_SCHEMA_FORMAT_REMOVE_DEFAULTS) + if self.config.external_schema_format == "json": + exported_schema_s = schema.to_pretty_json(remove_defaults=self.config.external_schema_format_remove_defaults) + elif self.config.external_schema_format == "yaml": + exported_schema_s = schema.to_pretty_yaml(remove_defaults=self.config.external_schema_format_remove_defaults) else: - raise ValueError(self.C.EXTERNAL_SCHEMA_FORMAT) + raise ValueError(self.config.external_schema_format) export_storage = FileStorage(export_path, makedirs=True) - schema_file = self._file_name_in_store(schema.name, self.C.EXTERNAL_SCHEMA_FORMAT) + schema_file = self._file_name_in_store(schema.name, self.config.external_schema_format) export_storage.save(schema_file, exported_schema_s) - logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.C.EXTERNAL_SCHEMA_FORMAT}") + logger.info(f"Schema {schema.name} exported to {export_path} with version {schema.stored_version} as {self.config.external_schema_format}") def _save_schema(self, schema: Schema) -> str: # save a schema to schema store diff --git a/dlt/common/storages/versioned_storage.py b/dlt/common/storages/versioned_storage.py index 9669e076e0..9dad05f9cc 100644 --- a/dlt/common/storages/versioned_storage.py +++ b/dlt/common/storages/versioned_storage.py @@ -1,6 +1,6 @@ import semver -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException @@ -31,7 +31,7 @@ def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileSto if is_owner: self._save_version(version) else: - raise WrongStorageVersionException(storage.storage_path, semver.VersionInfo.parse("0.0.0"), version) + raise WrongStorageVersionException(storage.storage_path, semver.VersionInfo.parse("0.0.0"), version) def migrate_storage(self, 
from_version: semver.VersionInfo, to_version: semver.VersionInfo) -> None: # migration example: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index e22f2ae21a..b29d8a495d 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,31 +1,38 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from re import Pattern as _REPattern -from typing import Callable, Dict, Any, Literal, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +from typing import Callable, Dict, Any, Literal, List, Mapping, NewType, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +from typing_extensions import TypeAlias, ParamSpec + if TYPE_CHECKING: from _typeshed import StrOrBytesPath + # from typing_extensions import ParamSpec from typing import _TypedDict REPattern = _REPattern[str] else: StrOrBytesPath = Any from typing import _TypedDictMeta as _TypedDict REPattern = _REPattern - -DictStrAny = Dict[str, Any] -DictStrStr = Dict[str, str] -StrAny = Mapping[str, Any] # immutable, covariant entity -StrStr = Mapping[str, str] # immutable, covariant entity -StrStrStr = Mapping[str, Mapping[str, str]] # immutable, covariant entity -TFun = TypeVar("TFun", bound=Callable[..., Any]) + # ParamSpec = lambda x: [x] + +DictStrAny: TypeAlias = Dict[str, Any] +DictStrStr: TypeAlias = Dict[str, str] +StrAny: TypeAlias = Mapping[str, Any] # immutable, covariant entity +StrStr: TypeAlias = Mapping[str, str] # immutable, covariant entity +StrStrStr: TypeAlias = Mapping[str, Mapping[str, str]] # immutable, covariant entity +AnyFun: TypeAlias = Callable[..., Any] +TFun = TypeVar("TFun", bound=AnyFun) # any function TAny = TypeVar("TAny", bound=Any) +TAnyClass = TypeVar("TAnyClass", bound=object) TSecretValue = NewType("TSecretValue", str) # represent secret value ie. 
coming from Kubernetes/Docker secrets or other providers -TDataItem = DictStrAny +TDataItem: TypeAlias = object # a single data item as extracted from data source +TDataItems: TypeAlias = Union[TDataItem, List[TDataItem]] # a single or many data items as extracted from the data source +ConfigValue: None = None # a value of type None indicating argument that may be injected by config provider TVariantBase = TypeVar("TVariantBase", covariant=True) TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" - @runtime_checkable class SupportsVariant(Protocol, Generic[TVariantBase]): """Defines variant type protocol that should be recognized by normalizers @@ -38,10 +45,7 @@ def __call__(self) -> Union[TVariantBase, TVariantRV]: def is_optional_type(t: Type[Any]) -> bool: - # todo: use typing get_args and get_origin in python 3.8 - if hasattr(t, "__origin__"): - return t.__origin__ is Union and type(None) in t.__args__ - return False + return get_origin(t) is Union and type(None) in get_args(t) def extract_optional_type(t: Type[Any]) -> Any: @@ -49,7 +53,11 @@ def extract_optional_type(t: Type[Any]) -> Any: def is_literal_type(hint: Type[Any]) -> bool: - return hasattr(hint, "__origin__") and hint.__origin__ is Literal + return get_origin(hint) is Literal + + +def is_newtype_type(t: Type[Any]) -> bool: + return hasattr(t, "__supertype__") def is_typeddict(t: Any) -> bool: @@ -58,15 +66,34 @@ def is_typeddict(t: Any) -> bool: def is_list_generic_type(t: Any) -> bool: try: - o = get_origin(t) - return issubclass(o, list) or issubclass(o, C_Sequence) - except Exception: + return issubclass(get_origin(t), C_Sequence) + except TypeError: return False def is_dict_generic_type(t: Any) -> bool: try: - o = get_origin(t) - return issubclass(o, dict) or issubclass(o, C_Mapping) - except Exception: + return issubclass(get_origin(t), C_Mapping) + except TypeError: return False + + +def extract_inner_type(hint: Type[Any]) -> Type[Any]: + """Gets the inner type from Literal, Optional and NewType + + Args: + hint (Type[Any]): Any type + + Returns: + Type[Any]: Inner type if hint was Literal, Optional or NewType, otherwise hint + """ + if is_literal_type(hint): + # assume that all literals are of the same type + return extract_inner_type(type(get_args(hint)[0])) + if is_optional_type(hint): + # extract optional type and call recursively + return extract_inner_type(get_args(hint)[0]) + if is_newtype_type(hint): + # descend into supertypes of NewType + return extract_inner_type(hint.__supertype__) + return hint diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 835c57d58f..18d65d45ac 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -1,13 +1,16 @@ -from functools import wraps import os +from pathlib import Path +import sys import base64 -from contextlib import contextmanager import hashlib -from os import environ import secrets +from contextlib import contextmanager +from functools import wraps +from os import environ + from typing import Any, Dict, Iterator, Optional, Sequence, TypeVar, Mapping, List, TypedDict, Union -from dlt.common.typing import StrAny, DictStrAny, StrStr, TFun +from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TFun T = TypeVar("T") @@ -22,7 +25,11 @@ def uniq_id(len_: int = 16) -> str: def digest128(v: str) -> str: - return base64.b64encode(hashlib.shake_128(v.encode("utf-8")).digest(15)).decode('ascii') + return base64.b64encode( + hashlib.shake_128( + v.encode("utf-8") + ).digest(15) + ).decode('ascii') def digest256(v: str) -> str: @@ -93,7 +100,7 @@ 
def flatten_dicts_of_dicts(dicts: Mapping[str, Any]) -> Sequence[Any]: def tuplify_list_of_dicts(dicts: Sequence[DictStrAny]) -> Sequence[DictStrAny]: """ - Transform dicts with single key into {"key": orig_key, "value": orig_value} + Transform list of dictionaries with single key into single dictionary of {"key": orig_key, "value": orig_value} """ for d in dicts: if len(d) > 1: @@ -152,16 +159,16 @@ def custom_environ(env: StrStr) -> Iterator[None]: def with_custom_environ(f: TFun) -> TFun: - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> Any: - saved_environ = os.environ.copy() - try: - return f(*args, **kwargs) - finally: - os.environ.clear() - os.environ.update(saved_environ) + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + saved_environ = os.environ.copy() + try: + return f(*args, **kwargs) + finally: + os.environ.clear() + os.environ.update(saved_environ) - return _wrap # type: ignore + return _wrap # type: ignore def encoding_for_mode(mode: str) -> Optional[str]: @@ -169,3 +176,15 @@ def encoding_for_mode(mode: str) -> Optional[str]: return None else: return "utf-8" + + +def entry_point_file_stem() -> str: + if len(sys.argv) > 0 and os.path.isfile(sys.argv[0]): + return Path(sys.argv[0]).stem + return None + + +def is_inner_function(f: AnyFun) -> bool: + """Checks if f is defined within other function""" + # inner functions have full nesting path in their qualname + return "<locals>" in f.__qualname__ diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 3c6cec9ad9..c13e54dd8b 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -8,13 +8,13 @@ TCustomValidator = Callable[[str, str, Any, Any], bool] -def validate_dict(schema: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFilterFuc = None, validator_f: TCustomValidator = None) -> None: +def validate_dict(spec: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFilterFuc = None, validator_f: TCustomValidator = None) -> None: # pass through filter filter_f = filter_f or (lambda _: True) # cannot validate anything validator_f = validator_f or (lambda p, pk, pv, t: False) - allowed_props = get_type_hints(schema) + allowed_props = get_type_hints(spec) required_props = {k: v for k, v in allowed_props.items() if not is_optional_type(v)} # remove optional props props = {k: v for k, v in doc.items() if filter_f(k)} @@ -34,7 +34,7 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: if is_literal_type(t): a_l = get_args(t) if pv not in a_l: - raise DictValidationException(f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv) + raise DictValidationException(f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv) elif t in [int, bool, str, float]: if not isinstance(pv, t): raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while {t.__name__} is expected", path, pk, pv) diff --git a/dlt/common/wei.py b/dlt/common/wei.py index 53babc23fc..218e5eee3a 100644 --- a/dlt/common/wei.py +++ b/dlt/common/wei.py @@ -1,8 +1,7 @@ from typing import Union -from dlt.common import Decimal from dlt.common.typing import TVariantRV, SupportsVariant -from dlt.common.arithmetics import default_context, decimal +from dlt.common.arithmetics import default_context, decimal, Decimal # default scale of EVM based blockchain WEI_SCALE = 18 diff --git a/dlt/dbt_runner/__init__.py b/dlt/dbt_runner/__init__.py index 7df9f7aa35..e69de29bb2 100644 --- a/dlt/dbt_runner/__init__.py +++ b/dlt/dbt_runner/__init__.py @@ -1 +0,0 @@ -from ._version
import __version__ \ No newline at end of file diff --git a/dlt/dbt_runner/_version.py b/dlt/dbt_runner/_version.py deleted file mode 100644 index 3dc1f76bc6..0000000000 --- a/dlt/dbt_runner/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.0" diff --git a/dlt/dbt_runner/configuration.py b/dlt/dbt_runner/configuration.py index f9279c20b1..b0217e9267 100644 --- a/dlt/dbt_runner/configuration.py +++ b/dlt/dbt_runner/configuration.py @@ -1,70 +1,57 @@ +import dataclasses +from os import environ from typing import List, Optional, Type from dlt.common.typing import StrAny, TSecretValue -from dlt.common.configuration import make_configuration -from dlt.common.configuration.providers import environ -from dlt.common.configuration import PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials - -from . import __version__ - - -class DBTRunnerConfiguration(PoolRunnerConfiguration): - POOL_TYPE: TPoolType = "none" - STOP_AFTER_RUNS: int = 1 - PACKAGE_VOLUME_PATH: str = "_storage/dbt_runner" - PACKAGE_REPOSITORY_URL: str = "https://github.com/scale-vector/rasa_semantic_schema_customization.git" - PACKAGE_REPOSITORY_BRANCH: Optional[str] = None - PACKAGE_REPOSITORY_SSH_KEY: TSecretValue = TSecretValue("") # the default is empty value which will disable custom SSH KEY - PACKAGE_PROFILES_DIR: str = "." - PACKAGE_PROFILE_PREFIX: str = "rasa_semantic_schema" - PACKAGE_SOURCE_TESTS_SELECTOR: str = "tag:prerequisites" - PACKAGE_ADDITIONAL_VARS: Optional[StrAny] = None - PACKAGE_RUN_PARAMS: List[str] = ["--fail-fast"] - AUTO_FULL_REFRESH_WHEN_OUT_OF_SYNC: bool = True - - SOURCE_SCHEMA_PREFIX: str = None - DEST_SCHEMA_PREFIX: Optional[str] = None - - @classmethod - def check_integrity(cls) -> None: - if cls.PACKAGE_REPOSITORY_SSH_KEY and cls.PACKAGE_REPOSITORY_SSH_KEY[-1] != "\n": +from dlt.common.configuration import resolve_configuration, configspec +from dlt.common.configuration.providers import EnvironProvider +from dlt.common.configuration.specs import RunConfiguration, PoolRunnerConfiguration, TPoolType, PostgresCredentials, GcpClientCredentials + + +@configspec +class DBTRunnerConfiguration(RunConfiguration, PoolRunnerConfiguration): + pool_type: TPoolType = "none" + stop_after_runs: int = 1 + package_volume_path: str = "/var/local/app" + package_repository_url: str = "https://github.com/scale-vector/rasa_semantic_schema_customization.git" + package_repository_branch: Optional[str] = None + package_repository_ssh_key: TSecretValue = TSecretValue("") # the default is empty value which will disable custom SSH KEY + package_profiles_dir: str = "." 
+ package_profile_prefix: str = "rasa_semantic_schema" + package_source_tests_selector: str = "tag:prerequisites" + package_additional_vars: Optional[StrAny] = None + package_run_params: List[str] = dataclasses.field(default_factory=lambda: ["--fail-fast"]) + auto_full_refresh_when_out_of_sync: bool = True + + source_schema_prefix: str = None + dest_schema_prefix: Optional[str] = None + + def check_integrity(self) -> None: + if self.package_repository_ssh_key and self.package_repository_ssh_key[-1] != "\n": # must end with new line, otherwise won't be parsed by Crypto - cls.PACKAGE_REPOSITORY_SSH_KEY = TSecretValue(cls.PACKAGE_REPOSITORY_SSH_KEY + "\n") - if cls.STOP_AFTER_RUNS != 1: + self.package_repository_ssh_key = TSecretValue(self.package_repository_ssh_key + "\n") + if self.stop_after_runs != 1: # always stop after one run - cls.STOP_AFTER_RUNS = 1 + self.stop_after_runs = 1 -class DBTRunnerProductionConfiguration(DBTRunnerConfiguration): - PACKAGE_VOLUME_PATH: str = "/var/local/app" # this is actually not exposed as volume - PACKAGE_REPOSITORY_URL: str = None - - -def gen_configuration_variant(initial_values: StrAny = None) -> Type[DBTRunnerConfiguration]: +def gen_configuration_variant(initial_values: StrAny = None) -> DBTRunnerConfiguration: # derive concrete config depending on env vars present DBTRunnerConfigurationImpl: Type[DBTRunnerConfiguration] - DBTRunnerProductionConfigurationImpl: Type[DBTRunnerProductionConfiguration] + environ = EnvironProvider() - source_schema_prefix = environ.get_key("DEFAULT_DATASET", type(str)) + source_schema_prefix: str = environ.get_value("dataset_name", type(str)) # type: ignore - if environ.get_key("PROJECT_ID", type(str), namespace=GcpClientCredentials.__namespace__): + if environ.get_value("project_id", type(str), GcpClientCredentials.__namespace__): + @configspec class DBTRunnerConfigurationPostgres(PostgresCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix DBTRunnerConfigurationImpl = DBTRunnerConfigurationPostgres - class DBTRunnerProductionConfigurationPostgres(DBTRunnerProductionConfiguration, DBTRunnerConfigurationPostgres): - pass - # SOURCE_SCHEMA_PREFIX: str = source_schema_prefix - DBTRunnerProductionConfigurationImpl = DBTRunnerProductionConfigurationPostgres - else: + @configspec class DBTRunnerConfigurationGcp(GcpClientCredentials, DBTRunnerConfiguration): SOURCE_SCHEMA_PREFIX: str = source_schema_prefix DBTRunnerConfigurationImpl = DBTRunnerConfigurationGcp - class DBTRunnerProductionConfigurationGcp(DBTRunnerProductionConfiguration, DBTRunnerConfigurationGcp): - pass - # SOURCE_SCHEMA_PREFIX: str = source_schema_prefix - DBTRunnerProductionConfigurationImpl = DBTRunnerProductionConfigurationGcp - - return make_configuration(DBTRunnerConfigurationImpl, DBTRunnerProductionConfigurationImpl, initial_values=initial_values) + return resolve_configuration(DBTRunnerConfigurationImpl(), initial_value=initial_values) diff --git a/dlt/dbt_runner/runner.py b/dlt/dbt_runner/runner.py index 5ad572fe54..9d34d360ce 100644 --- a/dlt/dbt_runner/runner.py +++ b/dlt/dbt_runner/runner.py @@ -1,15 +1,15 @@ -from typing import Optional, Sequence, Tuple, Type +from typing import Optional, Sequence, Tuple from git import GitError from prometheus_client import REGISTRY, Gauge, CollectorRegistry, Info from prometheus_client.metrics import MetricWrapperBase -from dlt.common.configuration import GcpClientCredentials from dlt.common import logger +from dlt.cli import TRunnerArgs from dlt.common.typing import 
DictStrAny, DictStrStr, StrAny from dlt.common.logger import is_json_logging from dlt.common.telemetry import get_logging_extras -from dlt.common.file_storage import FileStorage -from dlt.cli import TRunnerArgs +from dlt.common.configuration.specs import GcpClientCredentials +from dlt.common.storages import FileStorage from dlt.common.runners import initialize_runner, run_pool from dlt.common.telemetry import TRunMetrics @@ -20,7 +20,7 @@ CLONED_PACKAGE_NAME = "dbt_package" -CONFIG: Type[DBTRunnerConfiguration] = None +CONFIG: DBTRunnerConfiguration = None storage: FileStorage = None dbt_package_vars: StrAny = None global_args: Sequence[str] = None @@ -32,30 +32,30 @@ def create_folders() -> Tuple[FileStorage, StrAny, Sequence[str], str, str]: - storage = FileStorage(CONFIG.PACKAGE_VOLUME_PATH, makedirs=True) - dbt_package_vars: DictStrAny = { - "source_schema_prefix": CONFIG.SOURCE_SCHEMA_PREFIX + storage_ = FileStorage(CONFIG.package_volume_path, makedirs=True) + dbt_package_vars_: DictStrAny = { + "source_schema_prefix": CONFIG.source_schema_prefix } - if CONFIG.DEST_SCHEMA_PREFIX: - dbt_package_vars["dest_schema_prefix"] = CONFIG.DEST_SCHEMA_PREFIX - if CONFIG.PACKAGE_ADDITIONAL_VARS: - dbt_package_vars.update(CONFIG.PACKAGE_ADDITIONAL_VARS) + if CONFIG.dest_schema_prefix: + dbt_package_vars_["dest_schema_prefix"] = CONFIG.dest_schema_prefix + if CONFIG.package_additional_vars: + dbt_package_vars_.update(CONFIG.package_additional_vars) # initialize dbt logging, returns global parameters to dbt command - global_args = initialize_dbt_logging(CONFIG.LOG_LEVEL, is_json_logging(CONFIG.LOG_FORMAT)) + global_args_ = initialize_dbt_logging(CONFIG.log_level, is_json_logging(CONFIG.log_format)) # generate path for the dbt package repo - repo_path = storage._make_path(CLONED_PACKAGE_NAME) + repo_path_ = storage_.make_full_path(CLONED_PACKAGE_NAME) # generate profile name - profile_name: str = None - if CONFIG.PACKAGE_PROFILE_PREFIX: - if issubclass(CONFIG, GcpClientCredentials): - profile_name = "%s_bigquery" % (CONFIG.PACKAGE_PROFILE_PREFIX) + profile_name_: str = None + if CONFIG.package_profile_prefix: + if isinstance(CONFIG, GcpClientCredentials): + profile_name_ = "%s_bigquery" % (CONFIG.package_profile_prefix) else: - profile_name = "%s_redshift" % (CONFIG.PACKAGE_PROFILE_PREFIX) + profile_name_ = "%s_redshift" % (CONFIG.package_profile_prefix) - return storage, dbt_package_vars, global_args, repo_path, profile_name + return storage_, dbt_package_vars_, global_args_, repo_path_, profile_name_ def create_gauges(registry: CollectorRegistry) -> Tuple[MetricWrapperBase, MetricWrapperBase]: @@ -69,7 +69,7 @@ def run_dbt(command: str, command_args: Sequence[str] = None) -> Sequence[dbt_re logger.info(f"Exec dbt command: {global_args} {command} {command_args} {dbt_package_vars} on profile {profile_name or ''}") return run_dbt_command( repo_path, command, - CONFIG.PACKAGE_PROFILES_DIR, + CONFIG.package_profiles_dir, profile_name=profile_name, command_args=command_args, global_args=global_args, @@ -109,8 +109,8 @@ def initialize_package(with_git_command: Optional[str]) -> None: # cleanup package folder if storage.has_folder(CLONED_PACKAGE_NAME): storage.delete_folder(CLONED_PACKAGE_NAME, recursively=True) - logger.info(f"Will clone {CONFIG.PACKAGE_REPOSITORY_URL} head {CONFIG.PACKAGE_REPOSITORY_BRANCH} into {repo_path}") - clone_repo(CONFIG.PACKAGE_REPOSITORY_URL, repo_path, branch=CONFIG.PACKAGE_REPOSITORY_BRANCH, with_git_command=with_git_command) + logger.info(f"Will clone 
{CONFIG.package_repository_url} head {CONFIG.package_repository_branch} into {repo_path}") + clone_repo(CONFIG.package_repository_url, repo_path, branch=CONFIG.package_repository_branch, with_git_command=with_git_command) run_dbt("deps") except Exception: # delete folder so we start clean next time @@ -120,7 +120,7 @@ def initialize_package(with_git_command: Optional[str]) -> None: def ensure_newest_package() -> None: - with git_custom_key_command(CONFIG.PACKAGE_REPOSITORY_SSH_KEY) as ssh_command: + with git_custom_key_command(CONFIG.package_repository_ssh_key) as ssh_command: try: ensure_remote_head(repo_path, with_git_command=ssh_command) except GitError as err: @@ -134,8 +134,8 @@ def run_db_steps() -> Sequence[dbt_results.BaseResult]: ensure_newest_package() # check if raw schema exists try: - if CONFIG.PACKAGE_SOURCE_TESTS_SELECTOR: - run_dbt("test", ["-s", CONFIG.PACKAGE_SOURCE_TESTS_SELECTOR]) + if CONFIG.package_source_tests_selector: + run_dbt("test", ["-s", CONFIG.package_source_tests_selector]) except DBTProcessingError as err: raise PrerequisitesException() from err @@ -143,12 +143,12 @@ def run_db_steps() -> Sequence[dbt_results.BaseResult]: run_dbt("seed") # throws DBTProcessingError try: - return run_dbt("run", CONFIG.PACKAGE_RUN_PARAMS) + return run_dbt("run", CONFIG.package_run_params) except DBTProcessingError as e: # detect incremental model out of sync - if is_incremental_schema_out_of_sync_error(e.results) and CONFIG.AUTO_FULL_REFRESH_WHEN_OUT_OF_SYNC: + if is_incremental_schema_out_of_sync_error(e.results) and CONFIG.auto_full_refresh_when_out_of_sync: logger.warning(f"Attempting full refresh due to incremental model out of sync on {e.results.message}") - return run_dbt("run", CONFIG.PACKAGE_RUN_PARAMS + ["--full-refresh"]) + return run_dbt("run", CONFIG.package_run_params + ["--full-refresh"]) else: raise @@ -172,7 +172,7 @@ def run(_: None) -> TRunMetrics: raise -def configure(C: Type[DBTRunnerConfiguration], collector: CollectorRegistry) -> None: +def configure(C: DBTRunnerConfiguration, collector: CollectorRegistry) -> None: global CONFIG global storage, dbt_package_vars, global_args, repo_path, profile_name global model_elapsed_gauge, model_exec_info @@ -183,7 +183,7 @@ def configure(C: Type[DBTRunnerConfiguration], collector: CollectorRegistry) -> model_elapsed_gauge, model_exec_info = create_gauges(REGISTRY) except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated" not in str(v): raise diff --git a/dlt/dbt_runner/utils.py b/dlt/dbt_runner/utils.py index 2d426bd727..51a07d336a 100644 --- a/dlt/dbt_runner/utils.py +++ b/dlt/dbt_runner/utils.py @@ -50,7 +50,7 @@ def git_custom_key_command(private_key: Optional[str]) -> Iterator[str]: def ensure_remote_head(repo_path: str, with_git_command: Optional[str] = None) -> None: # update remotes and check if heads are same. 
ignores locally modified files repo = Repo(repo_path) - # use custom environemnt if specified + # use custom environment if specified with repo.git.custom_environment(GIT_SSH_COMMAND=with_git_command): # update origin repo.remote().update() diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py new file mode 100644 index 0000000000..4f17525cdf --- /dev/null +++ b/dlt/extract/decorators.py @@ -0,0 +1,329 @@ +import inspect +from types import ModuleType +from makefun import wraps +from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload + +from dlt.common.configuration import with_config, get_fun_spec +from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.exceptions import ArgumentsOverloadException +from dlt.common.schema.schema import Schema +from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition +from dlt.common.typing import AnyFun, ParamSpec, TDataItems +from dlt.common.utils import is_inner_function +from dlt.extract.exceptions import InvalidResourceDataTypeFunctionNotAGenerator + +from dlt.extract.typing import TTableHintTemplate, TFunHintTemplate +from dlt.extract.source import DltResource, DltSource + + +class SourceInfo(NamedTuple): + SPEC: Type[BaseConfiguration] + f: AnyFun + module: ModuleType + + +_SOURCES: Dict[str, SourceInfo] = {} + +TSourceFunParams = ParamSpec("TSourceFunParams") +TResourceFunParams = ParamSpec("TResourceFunParams") + + +@overload +def source(func: Callable[TSourceFunParams, Any], /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Callable[TSourceFunParams, DltSource]: + ... + +@overload +def source(func: None = ..., /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, DltSource]]: + ... + +def source(func: Optional[AnyFun] = None, /, name: str = None, schema: Schema = None, spec: Type[BaseConfiguration] = None) -> Any: + + # if name and schema: + # raise ArgumentsOverloadException( + # "source name cannot be set if schema is present", + # "source", + # "You can provide either the Schema instance directly in `schema` argument or the name of ") + + def decorator(f: Callable[TSourceFunParams, Any]) -> Callable[TSourceFunParams, DltSource]: + nonlocal schema, name + + # source name is passed directly or taken from decorated function name + name = name or f.__name__ + + if not schema: + # create or load default schema + # TODO: we need a convention to load ie. 
load the schema from file with name_schema.yaml + schema = Schema(name) + + # wrap source extraction function in configuration with namespace + conf_f = with_config(f, spec=spec, namespaces=("source", name)) + + @wraps(conf_f, func_name=name) + def _wrap(*args: Any, **kwargs: Any) -> DltSource: + rv = conf_f(*args, **kwargs) + + # if generator, consume it immediately + if inspect.isgenerator(rv): + rv = list(rv) + + # def check_rv_type(rv: Any) -> None: + # pass + + # # check if return type is list or tuple + # if isinstance(rv, (list, tuple)): + # # check all returned elements + # for v in rv: + # check_rv_type(v) + # else: + # check_rv_type(rv) + + # convert to source + return DltSource.from_data(schema, rv) + + # get spec for wrapped function + SPEC = get_fun_spec(conf_f) + # store the source information + _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, _wrap, inspect.getmodule(f)) + + # the typing is right, but makefun.wraps does not preserve signatures + return _wrap # type: ignore + + if func is None: + # we're called with parens. + return decorator + + if not callable(func): + raise ValueError("First parameter to the source must be a callable.") + + # we're called as @source without parens. + return decorator(func) + + +# @source +# def reveal_1() -> None: +# pass + +# @source(name="revel") +# def reveal_2() -> None: +# pass + + +# def revel_3(v) -> int: +# pass + + +# reveal_type(reveal_1) +# reveal_type(reveal_1()) + +# reveal_type(reveal_2) +# reveal_type(reveal_2()) + +# reveal_type(source(revel_3)) +# reveal_type(source(revel_3)("s")) + +@overload +def resource( + data: Callable[TResourceFunParams, Any], + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> Callable[TResourceFunParams, DltResource]: + ... + +@overload +def resource( + data: None = ..., + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> Callable[[Callable[TResourceFunParams, Any]], Callable[TResourceFunParams, DltResource]]: + ... + + +# @overload +# def resource( +# data: Union[DltSource, DltResource, Sequence[DltSource], Sequence[DltResource]], +# / +# ) -> DltResource: +# ... + + +@overload +def resource( + data: Union[List[Any], Tuple[Any], Iterator[Any]], + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> DltResource: + ... 
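# --- Editor's note: an illustrative sketch, not part of the patch, showing how the @resource
# overloads above are intended to be used. Parameter names are taken from the signatures in
# this file; the sample data, resource names and the "append" write disposition are assumptions.

from dlt.extract.decorators import resource

@resource(name="users", write_disposition="append")
def users():
    # a resource function must be a generator - it yields data items
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]

# plain data (a list, tuple or iterator) returns a DltResource directly; the name must be
# passed explicitly because there is no function name to fall back on
static_numbers = resource([{"n": 1}, {"n": 2}], name="numbers")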
+ + +def resource( + data: Optional[Any] = None, + /, + name: str = None, + table_name_fun: TFunHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + selected: bool = True, + depends_on: DltResource = None, + spec: Type[BaseConfiguration] = None +) -> Any: + + def make_resource(_name: str, _data: Any) -> DltResource: + table_template = DltResource.new_table_template(table_name_fun or _name, write_disposition=write_disposition, columns=columns) + return DltResource.from_data(_data, _name, table_template, selected, depends_on) + + + def decorator(f: Callable[TResourceFunParams, Any]) -> Callable[TResourceFunParams, DltResource]: + resource_name = name or f.__name__ + + # if f is not a generator (does not yield) raise Exception + if not inspect.isgeneratorfunction(inspect.unwrap(f)): + raise InvalidResourceDataTypeFunctionNotAGenerator(resource_name, f, type(f)) + + # do not inject config values for inner functions, we assume that they are part of the source + SPEC: Type[BaseConfiguration] = None + if is_inner_function(f): + conf_f = f + else: + # wrap source extraction function in configuration with namespace + conf_f = with_config(f, spec=spec, namespaces=("resource", resource_name)) + # get spec for wrapped function + SPEC = get_fun_spec(conf_f) + + # @wraps(conf_f, func_name=resource_name) + # def _wrap(*args: Any, **kwargs: Any) -> DltResource: + # return make_resource(resource_name, f(*args, **kwargs)) + + # store the standalone resource information + if SPEC: + _SOURCES[f.__qualname__] = SourceInfo(SPEC, f, inspect.getmodule(f)) + + # the typing is right, but makefun.wraps does not preserve signatures + return make_resource(resource_name, f) + + # if data is callable or none use decorator + if data is None: + # we're called with parens. 
+ return decorator + + if callable(data): + return decorator(data) + else: + return make_resource(name, data) + + +def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: + # find source function + parts = f.__qualname__.split(".") + parent_fun = ".".join(parts[:-2]) + return _SOURCES.get(parent_fun) + + +# @resource +# def reveal_1() -> None: +# pass + +# @resource(name="revel") +# def reveal_2() -> None: +# pass + + +# def revel_3(v) -> int: +# pass + + +# reveal_type(reveal_1) +# reveal_type(reveal_1()) + +# reveal_type(reveal_2) +# reveal_type(reveal_2()) + +# reveal_type(resource(revel_3)) +# reveal_type(resource(revel_3)("s")) + + +# reveal_type(resource([], name="aaaa")) +# reveal_type(resource("aaaaa", name="aaaa")) + +# name of dlt metadata as part of the item +# DLT_METADATA_FIELD = "_dlt_meta" + + +# class TEventDLTMeta(TypedDict, total=False): +# table_name: str # a root table in which store the event + + +# def append_dlt_meta(item: TBoundItem, name: str, value: Any) -> TBoundItem: +# if isinstance(item, abc.Sequence): +# for i in item: +# i.setdefault(DLT_METADATA_FIELD, {})[name] = value +# elif isinstance(item, dict): +# item.setdefault(DLT_METADATA_FIELD, {})[name] = value + +# return item + + +# def with_table_name(item: TBoundItem, table_name: str) -> TBoundItem: +# # normalize table name before adding +# return append_dlt_meta(item, "table_name", table_name) + + +# def get_table_name(item: StrAny) -> Optional[str]: +# if DLT_METADATA_FIELD in item: +# meta: TEventDLTMeta = item[DLT_METADATA_FIELD] +# return meta.get("table_name", None) +# return None + + +# def with_retry(max_retries: int = 3, retry_sleep: float = 1.0) -> Callable[[Callable[_TFunParams, TBoundItem]], Callable[_TFunParams, TBoundItem]]: + +# def decorator(f: Callable[_TFunParams, TBoundItem]) -> Callable[_TFunParams, TBoundItem]: + +# def _wrap(*args: Any, **kwargs: Any) -> TBoundItem: +# attempts = 0 +# while True: +# try: +# return f(*args, **kwargs) +# except Exception as exc: +# if attempts == max_retries: +# raise +# attempts += 1 +# logger.warning(f"Exception {exc} in iterator, retrying {attempts} / {max_retries}") +# sleep(retry_sleep) + +# return _wrap + +# return decorator + + +TBoundItems = TypeVar("TBoundItems", bound=TDataItems) +TDeferred = Callable[[], TBoundItems] +TDeferredFunParams = ParamSpec("TDeferredFunParams") + + +def defer(f: Callable[TDeferredFunParams, TBoundItems]) -> Callable[TDeferredFunParams, TDeferred[TBoundItems]]: + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> TDeferred[TBoundItems]: + def _curry() -> TBoundItems: + return f(*args, **kwargs) + return _curry + + return _wrap # type: ignore diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index 6582b526b7..a3fc162113 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -1,5 +1,116 @@ +from typing import Any, Type from dlt.common.exceptions import DltException class ExtractorException(DltException): pass + + +class DltSourceException(DltException): + pass + + +class DltResourceException(DltSourceException): + def __init__(self, resource_name: str, msg: str) -> None: + self.resource_name = resource_name + super().__init__(msg) + + +class PipeException(DltException): + pass + + +class CreatePipeException(PipeException): + pass + + +class PipeItemProcessingError(PipeException): + pass + + +# class InvalidIteratorException(PipelineException): +# def __init__(self, iterator: Any) -> None: +# super().__init__(f"Unsupported source iterator or iterable type: 
{type(iterator).__name__}") + + +# class InvalidItemException(PipelineException): +# def __init__(self, item: Any) -> None: +# super().__init__(f"Source yielded unsupported item type: {type(item).__name}. Only dictionaries, sequences and deferred items allowed.") + + +class ResourceNameMissing(DltResourceException): + def __init__(self) -> None: + super().__init__(None, """Resource name is missing. If you create a resource directly from data ie. from a list you must pass the name explicitly in `name` argument. + Please note that for resources created from functions or generators, the name is the function name by default.""") + + +class DependentResourceIsNotCallable(DltResourceException): + def __init__(self, resource_name: str) -> None: + super().__init__(resource_name, f"Attempted to call the dependent resource {resource_name}. Do not call the dependent resources. They will be called only when iterated.") + + +class ResourceNotFoundError(DltResourceException, KeyError): + def __init__(self, resource_name: str, context: str) -> None: + self.resource_name = resource_name + super().__init__(resource_name, f"Resource with a name {resource_name} could not be found. {context}") + + +class InvalidResourceDataType(DltResourceException): + def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> None: + self.item = item + self._typ = _typ + super().__init__(resource_name, f"Cannot create resource {resource_name} from specified data. " + msg) + + +class InvalidResourceDataTypeAsync(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, "Async iterators and generators are not valid resources. Please use standard iterators and generators that yield Awaitables instead (for example by yielding from async function without await") + + +class InvalidResourceDataTypeBasic(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please pass your data in a list or as a function yielding items. If you want to process just one data item, enclose it in a list.") + + +class InvalidResourceDataTypeFunctionNotAGenerator(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, "Please make sure that function decorated with @resource uses 'yield' to return the data.") + + +class InvalidResourceDataTypeMultiplePipes(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, "Resources with multiple parallel data pipes are not yet supported. This problem most often happens when you are creating a source with @source decorator that has several resources with the same name.") + + +class InvalidDependentResourceDataTypeGeneratorFunctionRequired(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, "Dependent resource must be a decorated function that takes data item as its only argument.") + + +class InvalidParentResourceDataType(InvalidResourceDataType): + def __init__(self, resource_name: str, item: Any,_typ: Type[Any]) -> None: + super().__init__(resource_name, item, _typ, f"A parent resource of {resource_name} is of type {_typ.__name__}. 
Did you forget to use '@resource` decorator or `resource` function?") + + +class InvalidParentResourceIsAFunction(DltResourceException): + def __init__(self, resource_name: str, func_name: str) -> None: + self.func_name = func_name + super().__init__(resource_name, f"A parent resource {func_name} of dependent resource {resource_name} is a function. Please decorate it with '@resource' or pass to 'resource' function.") + + +class TableNameMissing(DltSourceException): + def __init__(self) -> None: + super().__init__("""Table name is missing in table template. Please provide a string or a function that takes a data item as an argument""") + + +class InconsistentTableTemplate(DltSourceException): + def __init__(self, reason: str) -> None: + msg = f"A set of table hints provided to the resource is inconsistent: {reason}" + super().__init__(msg) + + +class DataItemRequiredForDynamicTableHints(DltSourceException): + def __init__(self, resource_name: str) -> None: + self.resource_name = resource_name + super().__init__(f"""An instance of resource's data required to generate table schema in resource {resource_name}. + One of table hints for that resource (typically table name) is a function and hint is computed separately for each instance of data extracted from that resource.""") diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py new file mode 100644 index 0000000000..418f9627c8 --- /dev/null +++ b/dlt/extract/extract.py @@ -0,0 +1,101 @@ +import os +from typing import ClassVar, List + +from dlt.common.utils import uniq_id +from dlt.common.typing import TDataItems, TDataItem +from dlt.common.schema import utils, TSchemaUpdate +from dlt.common.storages import NormalizeStorage, DataItemStorage +from dlt.common.configuration.specs import NormalizeVolumeConfiguration + + +from dlt.extract.pipe import PipeIterator +from dlt.extract.source import DltResource, DltSource + + +class ExtractorStorage(DataItemStorage, NormalizeStorage): + EXTRACT_FOLDER: ClassVar[str] = "extract" + + def __init__(self, C: NormalizeVolumeConfiguration) -> None: + # data item storage with jsonl with pua encoding + super().__init__("puae-jsonl", True, C) + self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) + + def create_extract_id(self) -> str: + extract_id = uniq_id() + self.storage.create_folder(self._get_extract_path(extract_id)) + return extract_id + + def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> None: + extract_path = self._get_extract_path(extract_id) + for file in self.storage.list_folder_files(extract_path, to_root=False): + from_file = os.path.join(extract_path, file) + to_file = os.path.join(NormalizeStorage.EXTRACTED_FOLDER, file) + if with_delete: + self.storage.atomic_rename(from_file, to_file) + else: + # create hardlink which will act as a copy + self.storage.link_hard(from_file, to_file) + if with_delete: + self.storage.delete_folder(extract_path, recursively=True) + + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") + return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) + + def _get_extract_path(self, extract_id: str) -> str: + return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) + + +def extract(source: DltSource, storage: ExtractorStorage, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> TSchemaUpdate: + # TODO: add 
metrics: number of items processed, also per resource and table + dynamic_tables: TSchemaUpdate = {} + schema = source.schema + extract_id = storage.create_extract_id() + + def _write_item(table_name: str, item: TDataItems) -> None: + # normalize table name before writing so the name match the name in schema + # note: normalize function should be cached so there's almost no penalty on frequent calling + # note: column schema is not required for jsonl writer used here + # event.pop(DLT_METADATA_FIELD, None) # type: ignore + storage.write_data_item(extract_id, schema.name, schema.normalize_table_name(table_name), item, None) + + def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: + table_name = resource._table_name_hint_fun(item) + existing_table = dynamic_tables.get(table_name) + if existing_table is None: + dynamic_tables[table_name] = [resource.table_schema(item)] + else: + # quick check if deep table merge is required + if resource._table_has_other_dynamic_hints: + new_table = resource.table_schema(item) + # this merges into existing table in place + utils.merge_tables(existing_table[0], new_table) + else: + # if there are no other dynamic hints besides name then we just leave the existing partial table + pass + # write to storage with inferred table name + _write_item(table_name, item) + + # yield from all selected pipes + for pipe_item in PipeIterator.from_pipes(source.resources.selected_pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval): + # get partial table from table template + # TODO: many resources may be returned. if that happens the item meta must be present with table name and this name must match one of resources + # TDataItemMeta(table_name, requires_resource, write_disposition, columns, parent etc.) 
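# --- Editor's note (not part of the patch): the dispatch below takes the dynamic branch only
# when the resource defines a callable table-name hint. A hedged sketch of such a resource,
# with invented field names, could look like:
#
#   from dlt.extract.decorators import resource
#
#   @resource(name="events", table_name_fun=lambda ev: ev["event_type"])
#   def events():
#       yield {"event_type": "page_view", "user": 1}
#       yield {"event_type": "purchase", "user": 2}
#
# Each yielded item is then written to a table named after its "event_type" value and a
# partial table schema for that name is produced via resource.table_schema(item).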
+ resource = source.resources.find_by_pipe(pipe_item.pipe) + if resource._table_name_hint_fun: + if isinstance(pipe_item.item, List): + for item in pipe_item.item: + _write_dynamic_table(resource, item) + else: + _write_dynamic_table(resource, pipe_item.item) + else: + # write item belonging to table with static name + _write_item(resource.name, pipe_item.item) + + # flush all buffered writers + storage.close_writers(extract_id) + storage.commit_extract_files(extract_id) + + # returns set of partial tables + return dynamic_tables + diff --git a/dlt/extract/extractor_storage.py b/dlt/extract/extractor_storage.py deleted file mode 100644 index 32e71f6fec..0000000000 --- a/dlt/extract/extractor_storage.py +++ /dev/null @@ -1,41 +0,0 @@ -import semver - -from dlt.common.json import json_typed_dumps -from dlt.common.typing import Any -from dlt.common.utils import uniq_id -from dlt.common.schema import normalize_schema_name -from dlt.common.file_storage import FileStorage -from dlt.common.storages.versioned_storage import VersionedStorage -from dlt.common.storages.normalize_storage import NormalizeStorage - - -class ExtractorStorageBase(VersionedStorage): - def __init__(self, version: semver.VersionInfo, is_owner: bool, storage: FileStorage, normalize_storage: NormalizeStorage) -> None: - self.normalize_storage = normalize_storage - super().__init__(version, is_owner, storage) - - def create_temp_folder(self) -> str: - tf_name = uniq_id() - self.storage.create_folder(tf_name) - return tf_name - - def save_json(self, name: str, d: Any) -> None: - # saves json using typed encoder - self.storage.save(name, json_typed_dumps(d)) - - def commit_events(self, schema_name: str, processed_file_path: str, dest_file_stem: str, no_processed_events: int, load_id: str, with_delete: bool = True) -> str: - # schema name cannot contain underscores - if schema_name != normalize_schema_name(schema_name): - raise ValueError(schema_name) - - dest_name = NormalizeStorage.build_extracted_file_name(schema_name, dest_file_stem, no_processed_events, load_id) - # if no events extracted from tracker, file is not saved - if no_processed_events > 0: - # moves file to possibly external storage and place in the dest folder atomically - self.storage.copy_cross_storage_atomically( - self.normalize_storage.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER, processed_file_path, dest_name) - - if with_delete: - self.storage.delete(processed_file_path) - - return dest_name diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py new file mode 100644 index 0000000000..5b71c5191d --- /dev/null +++ b/dlt/extract/pipe.py @@ -0,0 +1,447 @@ +import types +import asyncio +from asyncio import Future +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from threading import Thread +from typing import Optional, Sequence, Union, Callable, Iterable, Iterator, List, NamedTuple, Awaitable, Tuple, Type, TYPE_CHECKING + +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.typing import TDataItem, TDataItems + +from dlt.extract.exceptions import CreatePipeException, PipeItemProcessingError +from dlt.extract.typing import TPipedDataItems + +if TYPE_CHECKING: + TItemFuture = Future[TDataItems] +else: + TItemFuture = Future + +from dlt.common.time import sleep + + +class PipeItem(NamedTuple): + item: TDataItems + step: int + pipe: "Pipe" + + +class ResolvablePipeItem(NamedTuple): + # mypy unable to handle 
recursive types, ResolvablePipeItem should take itself in "item" + item: Union[TPipedDataItems, Iterator[TPipedDataItems]] + step: int + pipe: "Pipe" + + +class FuturePipeItem(NamedTuple): + item: TItemFuture + step: int + pipe: "Pipe" + + +class SourcePipeItem(NamedTuple): + item: Union[Iterator[TPipedDataItems], Iterator[ResolvablePipeItem]] + step: int + pipe: "Pipe" + + +# pipeline step may be iterator of data items or mapping function that returns data item or another iterator +TPipeStep = Union[ + Iterable[TPipedDataItems], + Iterator[TPipedDataItems], + Callable[[TDataItems], TPipedDataItems], + Callable[[TDataItems], Iterator[TPipedDataItems]], + Callable[[TDataItems], Iterator[ResolvablePipeItem]] +] + + +class ForkPipe: + def __init__(self, pipe: "Pipe", step: int = -1) -> None: + self._pipes: List[Tuple["Pipe", int]] = [] + self.add_pipe(pipe, step) + + def add_pipe(self, pipe: "Pipe", step: int = -1) -> None: + if pipe not in self._pipes: + self._pipes.append((pipe, step)) + + def has_pipe(self, pipe: "Pipe") -> bool: + return pipe in [p[0] for p in self._pipes] + + def __call__(self, item: TDataItems) -> Iterator[ResolvablePipeItem]: + for i, (pipe, step) in enumerate(self._pipes): + _it = item if i == 0 else deepcopy(item) + # always start at the beginning + yield ResolvablePipeItem(_it, step, pipe) + + +class FilterItem: + def __init__(self, filter_f: Callable[[TDataItem], bool]) -> None: + self._filter_f = filter_f + + def __call__(self, item: TDataItems) -> Optional[TDataItems]: + # item may be a list TDataItem or a single TDataItem + if isinstance(item, list): + item = [i for i in item if self._filter_f(i)] + if not item: + # item was fully consumed by the filter + return None + return item + else: + return item if self._filter_f(item) else None + + +class Pipe: + def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = None) -> None: + self.name = name + self._steps: List[TPipeStep] = steps or [] + self._backup_steps: List[TPipeStep] = None + self._pipe_id = f"{name}_{id(self)}" + self.parent = parent + + @classmethod + def from_iterable(cls, name: str, gen: Union[Iterable[TPipedDataItems], Iterator[TPipedDataItems]], parent: "Pipe" = None) -> "Pipe": + if isinstance(gen, Iterable): + gen = iter(gen) + return cls(name, [gen], parent=parent) + + @property + def head(self) -> TPipeStep: + return self._steps[0] + + @property + def tail(self) -> TPipeStep: + return self._steps[-1] + + @property + def steps(self) -> List[TPipeStep]: + return self._steps + + def __getitem__(self, i: int) -> TPipeStep: + return self._steps[i] + + def __len__(self) -> int: + return len(self._steps) + + def fork(self, child_pipe: "Pipe", child_step: int = -1) -> "Pipe": + if len(self._steps) == 0: + raise CreatePipeException("Cannot fork to empty pipe") + fork_step = self.tail + if not isinstance(fork_step, ForkPipe): + fork_step = ForkPipe(child_pipe, child_step) + self.add_step(fork_step) + else: + if not fork_step.has_pipe(child_pipe): + fork_step.add_pipe(child_pipe, child_step) + return self + + def clone(self) -> "Pipe": + p = Pipe(self.name, self._steps.copy(), self.parent) + # clone shares the id with the original + p._pipe_id = self._pipe_id + return p + + # def backup(self) -> None: + # if self.has_backup: + # raise PipeBackupException("Pipe backup already exists, restore pipe first") + # self._backup_steps = self._steps.copy() + + # @property + # def has_backup(self) -> bool: + # return self._backup_steps is not None + + + # def restore(self) -> None: + # if not 
self.has_backup: + # raise PipeBackupException("No pipe backup to restore") + # self._steps = self._backup_steps + # self._backup_steps = None + + def add_step(self, step: TPipeStep) -> "Pipe": + if len(self._steps) == 0 and self.parent is None: + # first element must be iterable or iterator + if not isinstance(step, (Iterable, Iterator)): + raise CreatePipeException("First step of independent pipe must be Iterable or Iterator") + else: + if isinstance(step, Iterable): + step = iter(step) + self._steps.append(step) + else: + if isinstance(step, (Iterable, Iterator)): + if self.parent is not None: + raise CreatePipeException("Iterable or Iterator cannot be a step in dependent pipe") + else: + raise CreatePipeException("Iterable or Iterator can only be a first step in independent pipe") + if not callable(step): + raise CreatePipeException("Pipe step must be a callable taking exactly one data item as input") + self._steps.append(step) + return self + + def full_pipe(self) -> "Pipe": + if self.parent: + pipe = self.parent.full_pipe().steps + else: + pipe = [] + + # return pipe with resolved dependencies + pipe.extend(self._steps) + return Pipe(self.name, pipe) + + def evaluate_head(self) -> None: + # if pipe head is callable then call it + if self.parent is None: + if callable(self.head): + self._steps[0] = self.head() # type: ignore + + def __repr__(self) -> str: + return f"Pipe {self.name} ({self._pipe_id}) at {id(self)}" + + +class PipeIterator(Iterator[PipeItem]): + + @configspec + class PipeIteratorConfiguration(BaseConfiguration): + max_parallel_items: int = 100 + workers: int = 5 + futures_poll_interval: float = 0.01 + + def __init__(self, max_parallel_items: int, workers: int, futures_poll_interval: float) -> None: + self.max_parallel_items = max_parallel_items + self.workers = workers + self.futures_poll_interval = futures_poll_interval + + self._async_pool: asyncio.AbstractEventLoop = None + self._async_pool_thread: Thread = None + self._thread_pool: ThreadPoolExecutor = None + self._sources: List[SourcePipeItem] = [] + self._futures: List[FuturePipeItem] = [] + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipe(cls, pipe: Pipe, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + if pipe.parent: + pipe = pipe.full_pipe() + # head must be iterator + pipe.evaluate_head() + assert isinstance(pipe.head, Iterator) + # create extractor + extract = cls(max_parallel_items, workers, futures_poll_interval) + # add as first source + extract._sources.append(SourcePipeItem(pipe.head, 0, pipe)) + return extract + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipes(cls, pipes: Sequence[Pipe], yield_parents: bool = True, *, max_parallel_items: int = 100, workers: int = 5, futures_poll_interval: float = 0.01) -> "PipeIterator": + extract = cls(max_parallel_items, workers, futures_poll_interval) + # TODO: consider removing cloning. 
pipes are single use and may be iterated only once; here we modify and immediately run them
+        # clone all pipes before iterating (recursively) as we will fork them and this adds steps
+        pipes = PipeIterator.clone_pipes(pipes)
+
+        def _fork_pipeline(pipe: Pipe) -> None:
+            print(f"forking: {pipe.name}")
+            if pipe.parent:
+                # fork the parent pipe
+                pipe.parent.fork(pipe)
+                # make the parent yield by sending a clone of item to itself with position at the end
+                if yield_parents and pipe.parent in pipes:
+                    # fork is last step of the pipe so it will yield
+                    pipe.parent.fork(pipe.parent, len(pipe.parent) - 1)
+                _fork_pipeline(pipe.parent)
+            else:
+                # head of independent pipe must be iterator
+                pipe.evaluate_head()
+                assert isinstance(pipe.head, Iterator)
+                # add every head as source only once
+                if not any(i.pipe == pipe for i in extract._sources):
+                    print("add to sources: " + pipe.name)
+                    extract._sources.append(SourcePipeItem(pipe.head, 0, pipe))
+
+        for pipe in reversed(pipes):
+            _fork_pipeline(pipe)
+
+        return extract
+
+    def __next__(self) -> PipeItem:
+        pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None
+        # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python
+        # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated)
+        while True:
+            # do we need a new item?
+            if pipe_item is None:
+                # process element from the futures
+                if len(self._futures) > 0:
+                    pipe_item = self._resolve_futures()
+                # if none then take element from the newest source
+                if pipe_item is None:
+                    pipe_item = self._get_source_item()
+
+                if pipe_item is None:
+                    if len(self._futures) == 0 and len(self._sources) == 0:
+                        # no more elements in futures or sources
+                        raise StopIteration()
+                    else:
+                        # print("waiting")
+                        sleep(self.futures_poll_interval)
+                    continue
+
+
+            item = pipe_item.item
+            # if item is an iterator, then add it as a new source
+            if isinstance(item, Iterator):
+                # print(f"adding iterable {item}")
+                self._sources.append(SourcePipeItem(item, pipe_item.step, pipe_item.pipe))
+                pipe_item = None
+                continue
+
+            if isinstance(item, Awaitable) or callable(item):
+                # do we have a free slot or one of the slots is done?
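# The branch below hands awaitables to a background event loop (created lazily in
# _ensure_async_pool further down) and plain callables to the thread pool. A minimal,
# standalone sketch of that event-loop-in-a-thread pattern; names are illustrative,
# not dlt API.
import asyncio
from threading import Thread

def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
    # the background thread owns the loop and runs it until stop() is called
    asyncio.set_event_loop(loop)
    loop.run_forever()

loop = asyncio.new_event_loop()
Thread(target=_run_loop, args=(loop,), daemon=True).start()

async def fetch_item() -> str:
    await asyncio.sleep(0.1)
    return "item"

# returns a concurrent.futures.Future that resolves once the coroutine finishes on the loop thread
future = asyncio.run_coroutine_threadsafe(fetch_item(), loop)
print(future.result())
loop.call_soon_threadsafe(loop.stop)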
+ if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: + if isinstance(item, Awaitable): + future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool()) + elif callable(item): + future = self._ensure_thread_pool().submit(item) + # print(future) + self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe)) # type: ignore + # pipe item consumed for now, request a new one + pipe_item = None + continue + else: + # print("maximum futures exceeded, waiting") + sleep(self.futures_poll_interval) + # try same item later + continue + + # if we are at the end of the pipe then yield element + # print(pipe_item) + if pipe_item.step == len(pipe_item.pipe) - 1: + # must be resolved + if isinstance(item, (Iterator, Awaitable)) or callable(pipe_item.pipe): + raise PipeItemProcessingError("Pipe item not processed", pipe_item) + # mypy not able to figure out that item was resolved + return pipe_item # type: ignore + + # advance to next step + step = pipe_item.pipe[pipe_item.step + 1] + assert callable(step) + next_item = step(item) + pipe_item = ResolvablePipeItem(next_item, pipe_item.step + 1, pipe_item.pipe) + + + def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: + # lazily create async pool is separate thread + if self._async_pool: + return self._async_pool + + def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._async_pool = asyncio.new_event_loop() + self._async_pool_thread = Thread(target=start_background_loop, args=(self._async_pool,), daemon=True) + self._async_pool_thread.start() + + # start or return async pool + return self._async_pool + + def _ensure_thread_pool(self) -> ThreadPoolExecutor: + # lazily start or return thread pool + if self._thread_pool: + return self._thread_pool + + self._thread_pool = ThreadPoolExecutor(self.workers) + return self._thread_pool + + def __enter__(self) -> "PipeIterator": + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType) -> None: + + def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: + loop.stop() + + for f, _, _ in self._futures: + if not f.done(): + f.cancel() + print("stopping loop") + if self._async_pool: + self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) + print("joining thread") + self._async_pool_thread.join() + self._async_pool = None + self._async_pool_thread = None + if self._thread_pool: + self._thread_pool.shutdown(wait=True) + self._thread_pool = None + + def _next_future(self) -> int: + return next((i for i, val in enumerate(self._futures) if val.item.done()), -1) + + def _resolve_futures(self) -> ResolvablePipeItem: + # no futures at all + if len(self._futures) == 0: + return None + + # anything done? 
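# _next_future above is used immediately below to find the first completed future; together
# with futures_poll_interval this forms a simple non-blocking poll loop. A standalone sketch
# of the same polling idea; names are illustrative, not dlt API.
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List

def next_done(futures: List[Future]) -> int:
    # index of the first completed future, or -1 if none has finished yet
    return next((i for i, f in enumerate(futures) if f.done()), -1)

with ThreadPoolExecutor(max_workers=2) as pool:
    pending = [pool.submit(time.sleep, delay) for delay in (0.2, 0.05)]
    while (idx := next_done(pending)) == -1:
        time.sleep(0.01)  # analogous to futures_poll_interval
    print(f"future {idx} completed first")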
+ idx = self._next_future() + if idx == -1: + # nothing done + return None + + future, step, pipe = self._futures.pop(idx) + + if future.cancelled(): + # get next future + return self._resolve_futures() + + if future.exception(): + raise future.exception() + + return ResolvablePipeItem(future.result(), step, pipe) + + def _get_source_item(self) -> ResolvablePipeItem: + # no more sources to iterate + if len(self._sources) == 0: + return None + + # get items from last added iterator, this makes the overall Pipe as close to FIFO as possible + gen, step, pipe = self._sources[-1] + try: + item = next(gen) + # full pipe item may be returned, this is used by ForkPipe step + # to redirect execution of an item to another pipe + if isinstance(item, ResolvablePipeItem): + return item + else: + # keep the item assigned step and pipe + return ResolvablePipeItem(item, step, pipe) + except StopIteration: + # remove empty iterator and try another source + self._sources.pop() + return self._get_source_item() + + @staticmethod + def clone_pipes(pipes: Sequence[Pipe]) -> Sequence[Pipe]: + # will clone the pipes including the dependent ones + cloned_pipes = [p.clone() for p in pipes] + cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} + + for clone in cloned_pipes: + while True: + if not clone.parent: + break + # if already a clone + if clone.parent in cloned_pairs.values(): + break + # clone if parent pipe not yet cloned + if id(clone.parent) not in cloned_pairs: + print("cloning:" + clone.parent.name) + cloned_pairs[id(clone.parent)] = clone.parent.clone() + # replace with clone + print(f"replace depends on {clone.name} to {clone.parent.name}") + clone.parent = cloned_pairs[id(clone.parent)] + # recurr with clone + clone = clone.parent + + return cloned_pipes diff --git a/dlt/extract/source.py b/dlt/extract/source.py new file mode 100644 index 0000000000..85499fda57 --- /dev/null +++ b/dlt/extract/source.py @@ -0,0 +1,331 @@ +import contextlib +from copy import deepcopy +import inspect +from collections.abc import Mapping as C_Mapping +from typing import AsyncIterable, AsyncIterator, Dict, Iterable, Iterator, List, Set, Sequence, Union, cast, Any +from typing_extensions import Self + +from dlt.common.schema import Schema +from dlt.common.schema.utils import new_table +from dlt.common.schema.typing import TPartialTableSchema, TTableSchemaColumns, TWriteDisposition +from dlt.common.typing import AnyFun, TDataItem, TDataItems +from dlt.common.configuration.container import Container +from dlt.common.pipeline import PipelineContext + +from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, TTableSchemaTemplate +from dlt.extract.pipe import FilterItem, Pipe, PipeIterator +from dlt.extract.exceptions import ( + DependentResourceIsNotCallable, InvalidDependentResourceDataTypeGeneratorFunctionRequired, InvalidParentResourceDataType, InvalidParentResourceIsAFunction, InvalidResourceDataType, InvalidResourceDataTypeFunctionNotAGenerator, + ResourceNotFoundError, CreatePipeException, DataItemRequiredForDynamicTableHints, InconsistentTableTemplate, InvalidResourceDataTypeAsync, InvalidResourceDataTypeBasic, + InvalidResourceDataTypeMultiplePipes, ResourceNameMissing, TableNameMissing) + + +class DltResourceSchema: + def __init__(self, name: str, table_schema_template: TTableSchemaTemplate = None): + # self.__name__ = name + self.name = name + self._table_name_hint_fun: TFunHintTemplate[str] = None + self._table_has_other_dynamic_hints: bool = False + self._table_schema_template: 
TTableSchemaTemplate = None + self._table_schema: TPartialTableSchema = None + if table_schema_template: + self.set_template(table_schema_template) + + def table_schema(self, item: TDataItem = None) -> TPartialTableSchema: + if not self._table_schema_template: + # if table template is not present, generate partial table from name + if not self._table_schema: + self._table_schema = new_table(self.name) + return self._table_schema + + def _resolve_hint(hint: TTableHintTemplate[Any]) -> Any: + if callable(hint): + return hint(item) + else: + return hint + + # if table template present and has dynamic hints, the data item must be provided + if self._table_name_hint_fun: + if item is None: + raise DataItemRequiredForDynamicTableHints(self.name) + else: + # cloned_template = deepcopy(self._table_schema_template) + return cast(TPartialTableSchema, {k: _resolve_hint(v) for k, v in self._table_schema_template.items()}) + else: + return cast(TPartialTableSchema, self._table_schema_template) + + def apply_hints( + self, + table_name: TTableHintTemplate[str] = None, + parent_table_name: TTableHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + ) -> None: + t = None + if not self._table_schema_template: + # if there's no template yet, create and set new one + t = self.new_table_template(table_name, parent_table_name, write_disposition, columns) + else: + # set single hints + t = deepcopy(self._table_schema_template) + if table_name: + t["name"] = table_name + if parent_table_name: + t["parent"] = parent_table_name + if write_disposition: + t["write_disposition"] = write_disposition + if columns: + t["columns"] = columns + self.set_template(t) + + def set_template(self, table_schema_template: TTableSchemaTemplate) -> None: + # if "name" is callable in the template then the table schema requires actual data item to be inferred + name_hint = table_schema_template["name"] + if callable(name_hint): + self._table_name_hint_fun = name_hint + else: + self._table_name_hint_fun = None + # check if any other hints in the table template should be inferred from data + self._table_has_other_dynamic_hints = any(callable(v) for k, v in table_schema_template.items() if k != "name") + self._table_schema_template = table_schema_template + + @staticmethod + def new_table_template( + table_name: TTableHintTemplate[str], + parent_table_name: TTableHintTemplate[str] = None, + write_disposition: TTableHintTemplate[TWriteDisposition] = None, + columns: TTableHintTemplate[TTableSchemaColumns] = None, + ) -> TTableSchemaTemplate: + if not table_name: + raise TableNameMissing() + # create a table schema template where hints can be functions taking TDataItem + if isinstance(columns, C_Mapping): + # new_table accepts a sequence + columns = columns.values() # type: ignore + + new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns) # type: ignore + # if any of the hints is a function then name must be as well + if any(callable(v) for k, v in new_template.items() if k != "name") and not callable(table_name): + raise InconsistentTableTemplate("Table name must be a function if any other table hint is a function") + return new_template + + +class DltResource(Iterable[TDataItems], DltResourceSchema): + def __init__(self, pipe: Pipe, table_schema_template: TTableSchemaTemplate, selected: bool): + # TODO: allow resource to take name independent from pipe name + 
self.name = pipe.name + self.selected = selected + self._pipe = pipe + super().__init__(self.name, table_schema_template) + + @classmethod + def from_data(cls, data: Any, name: str = None, table_schema_template: TTableSchemaTemplate = None, selected: bool = True, depends_on: Union["DltResource", Pipe] = None) -> "DltResource": + + if isinstance(data, DltResource): + return data + + if isinstance(data, Pipe): + return cls(data, table_schema_template, selected) + + if callable(data): + name = name or data.__name__ + # function must be a generator + if not inspect.isgeneratorfunction(inspect.unwrap(data)): + raise InvalidResourceDataTypeFunctionNotAGenerator(name, data, type(data)) + + # if generator, take name from it + if inspect.isgenerator(data): + name = name or data.__name__ + + # name is mandatory + if not name: + raise ResourceNameMissing() + + # several iterable types are not allowed and must be excluded right away + if isinstance(data, (AsyncIterator, AsyncIterable)): + raise InvalidResourceDataTypeAsync(name, data, type(data)) + if isinstance(data, (str, dict)): + raise InvalidResourceDataTypeBasic(name, data, type(data)) + + # check if depends_on is a valid resource + parent_pipe: Pipe = None + if depends_on: + # must be a callable with single argument + if not callable(data): + raise InvalidDependentResourceDataTypeGeneratorFunctionRequired(name, data, type(data)) + else: + if cls.is_valid_dependent_generator_function(data): + raise InvalidDependentResourceDataTypeGeneratorFunctionRequired(name, data, type(data)) + # parent resource + if isinstance(depends_on, Pipe): + parent_pipe = depends_on + elif isinstance(depends_on, DltResource): + parent_pipe = depends_on._pipe + else: + # if this is generator function provide nicer exception + if callable(depends_on): + raise InvalidParentResourceIsAFunction(name, depends_on.__name__) + else: + raise InvalidParentResourceDataType(name, depends_on, type(depends_on)) + + # create resource from iterator, iterable or generator function + if isinstance(data, (Iterable, Iterator)): + pipe = Pipe.from_iterable(name, data, parent=parent_pipe) + elif callable(data): + pipe = Pipe(name, [data], parent_pipe) + if pipe: + return cls(pipe, table_schema_template, selected) + else: + # some other data type that is not supported + raise InvalidResourceDataType(name, data, type(data), f"The data type is {type(data).__name__}") + + + def add_pipe(self, data: Any) -> None: + """Creates additional pipe for the resource from the specified data""" + # TODO: (1) self resource cannot be a dependent one (2) if data is resource both self must and it must be selected/unselected + cannot be dependent + raise InvalidResourceDataTypeMultiplePipes(self.name, data, type(data)) + + + def select(self, *table_names: Iterable[str]) -> "DltResource": + if not self._table_name_hint_fun: + raise CreatePipeException("Table name is not dynamic, table selection impossible") + + def _filter(item: TDataItem) -> bool: + return self._table_name_hint_fun(item) in table_names + + # add filtering function at the end of pipe + self._pipe.add_step(FilterItem(_filter)) + return self + + def map(self) -> None: # noqa: A003 + raise NotImplementedError() + + def flat_map(self) -> None: + raise NotImplementedError() + + def filter(self) -> None: # noqa: A003 + raise NotImplementedError() + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # make resource callable to support parametrized resources which are functions taking arguments + if self._pipe.parent: + raise 
DependentResourceIsNotCallable(self.name) + # pass the call parameters to the pipe's head + _data = self._pipe.head(*args, **kwargs) # type: ignore + # create new resource from extracted data + return DltResource.from_data(_data, self.name, self._table_schema_template, self.selected, self._pipe.parent) + + def __iter__(self) -> Iterator[TDataItems]: + return map(lambda item: item.item, PipeIterator.from_pipe(self._pipe)) + + def __repr__(self) -> str: + return f"DltResource {self.name} ({self._pipe._pipe_id}) at {id(self)}" + + @staticmethod + def is_valid_dependent_generator_function(f: AnyFun) -> bool: + sig = inspect.signature(f) + return len(sig.parameters) == 0 + + +class DltResourceDict(Dict[str, DltResource]): + @property + def selected(self) -> Dict[str, DltResource]: + return {k:v for k,v in self.items() if v.selected} + + @property + def pipes(self) -> List[Pipe]: + # TODO: many resources may share the same pipe so return ordered set + return [r._pipe for r in self.values()] + + @property + def selected_pipes(self) -> Sequence[Pipe]: + # TODO: many resources may share the same pipe so return ordered set + return [r._pipe for r in self.values() if r.selected] + + def select(self, *resource_names: str) -> Dict[str, DltResource]: + # checks if keys are present + for name in resource_names: + try: + self.__getitem__(name) + except KeyError: + raise ResourceNotFoundError(name, "Requested resource could not be selected because it is not present in the source.") + # set the selected flags + for resource in self.values(): + self[resource.name].selected = resource.name in resource_names + return self.selected + + def find_by_pipe(self, pipe: Pipe) -> DltResource: + # TODO: many resources may share the same pipe so return a list and also filter the resources by self._enabled_resource_names + # identify pipes by memory pointer + return next(r for r in self.values() if r._pipe._pipe_id is pipe._pipe_id) + + +class DltSource(Iterable[TDataItems]): + def __init__(self, schema: Schema, resources: Sequence[DltResource] = None) -> None: + self.name = schema.name + self._schema = schema + self._resources: DltResourceDict = DltResourceDict() + if resources: + for resource in resources: + self._add_resource(resource) + + @classmethod + def from_data(cls, schema: Schema, data: Any) -> "DltSource": + # creates source from various forms of data + if isinstance(data, DltSource): + return data + + # in case of sequence, enumerate items and convert them into resources + if isinstance(data, Sequence): + resources = [DltResource.from_data(i) for i in data] + else: + resources = [DltResource.from_data(data)] + + return cls(schema, resources) + + + @property + def resources(self) -> DltResourceDict: + return self._resources + + @property + def selected_resources(self) -> Dict[str, DltResource]: + return self._resources.selected + + @property + def schema(self) -> Schema: + return self._schema + + @schema.setter + def schema(self, value: Schema) -> None: + self._schema = value + + def discover_schema(self) -> Schema: + # extract tables from all resources and update internal schema + for r in self._resources.values(): + # names must be normalized here + with contextlib.suppress(DataItemRequiredForDynamicTableHints): + partial_table = self._schema.normalize_table_identifiers(r.table_schema()) + self._schema.update_schema(partial_table) + return self._schema + + def with_resources(self, *resource_names: str) -> "DltSource": + self._resources.select(*resource_names) + return self + + + def run(self, destination: 
Any) -> Any: + return Container()[PipelineContext].pipeline().run(source=self, destination=destination) + + def _add_resource(self, resource: DltResource) -> None: + if resource.name in self._resources: + # for resources with the same name try to add the resource as an another pipe + self._resources[resource.name].add_pipe(resource) + else: + self._resources[resource.name] = resource + + def __iter__(self) -> Iterator[TDataItems]: + return map(lambda item: item.item, PipeIterator.from_pipes(self._resources.selected_pipes)) + + def __repr__(self) -> str: + return f"DltSource {self.name} at {id(self)}" diff --git a/dlt/extract/typing.py b/dlt/extract/typing.py new file mode 100644 index 0000000000..b7b33b6c65 --- /dev/null +++ b/dlt/extract/typing.py @@ -0,0 +1,22 @@ +from typing import Callable, TypedDict, TypeVar, Union, List, Awaitable + +from dlt.common.typing import TDataItem, TDataItems +from dlt.common.schema.typing import TTableSchemaColumns, TWriteDisposition + + +TDeferredDataItems = Callable[[], TDataItems] +TAwaitableDataItems = Awaitable[TDataItems] +TPipedDataItems = Union[TDataItems, TDeferredDataItems, TAwaitableDataItems] + +TDynHintType = TypeVar("TDynHintType") +TFunHintTemplate = Callable[[TDataItem], TDynHintType] +TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] + + +class TTableSchemaTemplate(TypedDict, total=False): + name: TTableHintTemplate[str] + # description: TTableHintTemplate[str] + write_disposition: TTableHintTemplate[TWriteDisposition] + # table_sealed: Optional[bool] + parent: TTableHintTemplate[str] + columns: TTableHintTemplate[TTableSchemaColumns] diff --git a/dlt/helpers/pandas.py b/dlt/helpers/pandas.py index 3d9dcb1c59..1fb5249929 100644 --- a/dlt/helpers/pandas.py +++ b/dlt/helpers/pandas.py @@ -1,7 +1,7 @@ from typing import Any from dlt.pipeline.exceptions import MissingDependencyException -from dlt.load.client_base import SqlClientBase +from dlt.load.sql_client import SqlClientBase try: import pandas as pd diff --git a/dlt/helpers/streamlit.py b/dlt/helpers/streamlit.py index ed1c3f624e..383959d55f 100644 --- a/dlt/helpers/streamlit.py +++ b/dlt/helpers/streamlit.py @@ -1,170 +1,170 @@ - -import os -import tomlkit -from tomlkit.container import Container as TomlContainer -from typing import cast -from copy import deepcopy - -from dlt.pipeline import Pipeline -from dlt.pipeline.typing import credentials_from_dict -from dlt.pipeline.exceptions import MissingDependencyException, PipelineException -from dlt.helpers.pandas import query_results_to_df, pd -from dlt.common.configuration.run_configuration import BaseConfiguration, CredentialsConfiguration -from dlt.common.utils import dict_remove_nones_in_place - -try: - import streamlit as st - from streamlit import SECRETS_FILE_LOC, secrets -except ImportError: - raise MissingDependencyException("DLT Streamlit Helpers", ["streamlit"], "DLT Helpers for Streamlit should be run within a streamlit app.") - - -def restore_pipeline() -> Pipeline: - """Restores Pipeline instance and associated credentials from Streamlit secrets - - Current implementation requires that pipeline working dir is available at the location saved in secrets. 
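# Usage sketch for the DltResource API added in dlt/extract/source.py above. Assumptions:
# the import path is the one introduced by this diff and iterating a resource drives
# PipeIterator.from_pipe under default configuration; this is an illustration, not a
# documented entry point.
from dlt.extract.source import DltResource

def numbers():
    # a generator function is a valid resource data type per DltResource.from_data
    for n in range(3):
        yield {"n": n}

resource = DltResource.from_data(numbers)
for item in resource:
    print(item)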
- - Raises: - PipelineBackupNotFound: Raised when pipeline backup is not available - CannotRestorePipelineException: Raised when pipeline working dir is not found or invalid - - Returns: - Pipeline: Instance of pipeline with attached credentials - """ - if "dlt" not in secrets: - raise PipelineException("You must backup pipeline to Streamlit first") - dlt_cfg = secrets["dlt"] - credentials = deepcopy(dict(dlt_cfg["destination"])) - if "DEFAULT_SCHEMA_NAME" in credentials: - del credentials["DEFAULT_SCHEMA_NAME"] - credentials.update(dlt_cfg["credentials"]) - pipeline = Pipeline(dlt_cfg["pipeline_name"]) - pipeline.restore_pipeline(credentials_from_dict(credentials), dlt_cfg["working_dir"]) - return pipeline - - -def backup_pipeline(pipeline: Pipeline) -> None: - """Backups pipeline state to the `secrets.toml` of the Streamlit app. - - Pipeline credentials and working directory will be added to the Streamlit `secrets` file. This allows to access query the data loaded to the destination and - access definitions of the inferred schemas. See `restore_pipeline` and `write_data_explorer_page` functions in the same module. - - Args: - pipeline (Pipeline): Pipeline instance, typically restored with `restore_pipeline` - """ - # save pipeline state to project .config - # config_file_name = file_util.get_project_streamlit_file_path("config.toml") - - # save credentials to secrets - if os.path.isfile(SECRETS_FILE_LOC): - with open(SECRETS_FILE_LOC, "r", encoding="utf-8") as f: - # use whitespace preserving parser - secrets = tomlkit.load(f) - else: - secrets = tomlkit.document() - - # save general settings - secrets["dlt"] = { - "working_dir": pipeline.working_dir, - "pipeline_name": pipeline.pipeline_name - } - - # get client config - # TODO: pipeline api v2 should provide a direct method to get configurations - CONFIG: BaseConfiguration = pipeline._loader_instance.load_client_cls.CONFIG # type: ignore - CREDENTIALS: CredentialsConfiguration = pipeline._loader_instance.load_client_cls.CREDENTIALS # type: ignore - - # save client config - # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) - dlt_c = cast(TomlContainer, secrets["dlt"]) - dlt_c["destination"] = dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False)) - dlt_c["credentials"] = dict_remove_nones_in_place(CREDENTIALS.as_dict(lowercase=False)) - - with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: - # use whitespace preserving parser - tomlkit.dump(secrets, f) - - -def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: - """Writes Streamlit app page with a schema and live data preview. - - Args: - pipeline (Pipeline): Pipeline instance to use. - schema_name (str, optional): Name of the schema to display. If None, default schema is used. - show_dlt_tables (bool, optional): Should show DLT internal tables. Defaults to False. - example_query (str, optional): Example query to be displayed in the SQL Query box. - show_charts (bool, optional): Should automatically show charts for the queries from SQL Query box. Defaults to True. 
- - Raises: - MissingDependencyException: Raised when a particular python dependency is not installed - """ - @st.experimental_memo(ttl=600) - def run_query(query: str) -> pd.DataFrame: - # dlt pipeline exposes configured sql client that (among others) let's you make queries against the warehouse - with pipeline.sql_client(schema_name) as client: - df = query_results_to_df(client, query) - return df - - if schema_name: - schema = pipeline.get_schema(schema_name) - else: - schema = pipeline.get_default_schema() - st.title(f"Available tables in {schema.name} schema") - # st.text(schema.to_pretty_yaml()) - - for table in schema.all_tables(with_dlt_tables=show_dlt_tables): - table_name = table["name"] - st.header(table_name) - if "description" in table: - st.text(table["description"]) - if "parent" in table: - st.text("Parent table: " + table["parent"]) - - # table schema contains various hints (like clustering or partition options) that we do not want to show in basic view - essentials_f = lambda c: {k:v for k, v in c.items() if k in ["name", "data_type", "nullable"]} - - st.table(map(essentials_f, table["columns"].values())) - # add a button that when pressed will show the full content of a table - if st.button("SHOW DATA", key=table_name): - st.text(f"Full {table_name} table content") - st.dataframe(run_query(f"SELECT * FROM {table_name}")) - - st.title("Run your query") - sql_query = st.text_area("Enter your SQL query", value=example_query) - if st.button("Run Query"): - if sql_query: - st.text("Results of a query") - try: - # run the query from the text area - df = run_query(sql_query) - # and display the results - st.dataframe(df) - - try: - # now if the dataset has supported shape try to display the bar or altair chart - if df.dtypes.shape[0] == 1 and show_charts: - # try barchart - st.bar_chart(df) - if df.dtypes.shape[0] == 2 and show_charts: - - # try to import altair charts - try: - import altair as alt - except ImportError: - raise MissingDependencyException( - "DLT Streamlit Helpers", - ["altair"], - "DLT Helpers for Streamlit should be run within a streamlit app." - ) - - # try altair - bar_chart = alt.Chart(df).mark_bar().encode( - x=f'{df.columns[1]}:Q', - y=alt.Y(f'{df.columns[0]}:N', sort='-x') - ) - st.altair_chart(bar_chart, use_container_width=True) - except Exception as ex: - st.error(f"Chart failed due to: {ex}") - except Exception as ex: - st.text("Exception when running query") - st.exception(ex) +# import os +# import tomlkit +# from tomlkit.container import Container as TomlContainer +# from typing import cast +# from copy import deepcopy + +# from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration +# from dlt.common.utils import dict_remove_nones_in_place + +# from dlt.pipeline import Pipeline +# from dlt.pipeline.typing import credentials_from_dict +# from dlt.pipeline.exceptions import MissingDependencyException, PipelineException +# from dlt.helpers.pandas import query_results_to_df, pd + +# try: +# import streamlit as st +# from streamlit import SECRETS_FILE_LOC, secrets +# except ImportError: +# raise MissingDependencyException("DLT Streamlit Helpers", ["streamlit"], "DLT Helpers for Streamlit should be run within a streamlit app.") + + +# def restore_pipeline() -> Pipeline: +# """Restores Pipeline instance and associated credentials from Streamlit secrets + +# Current implementation requires that pipeline working dir is available at the location saved in secrets. 
+ +# Raises: +# PipelineBackupNotFound: Raised when pipeline backup is not available +# CannotRestorePipelineException: Raised when pipeline working dir is not found or invalid + +# Returns: +# Pipeline: Instance of pipeline with attached credentials +# """ +# if "dlt" not in secrets: +# raise PipelineException("You must backup pipeline to Streamlit first") +# dlt_cfg = secrets["dlt"] +# credentials = deepcopy(dict(dlt_cfg["destination"])) +# if "default_schema_name" in credentials: +# del credentials["default_schema_name"] +# credentials.update(dlt_cfg["credentials"]) +# pipeline = Pipeline(dlt_cfg["pipeline_name"]) +# pipeline.restore_pipeline(credentials_from_dict(credentials), dlt_cfg["working_dir"]) +# return pipeline + + +# def backup_pipeline(pipeline: Pipeline) -> None: +# """Backups pipeline state to the `secrets.toml` of the Streamlit app. + +# Pipeline credentials and working directory will be added to the Streamlit `secrets` file. This allows to access query the data loaded to the destination and +# access definitions of the inferred schemas. See `restore_pipeline` and `write_data_explorer_page` functions in the same module. + +# Args: +# pipeline (Pipeline): Pipeline instance, typically restored with `restore_pipeline` +# """ +# # save pipeline state to project .config +# # config_file_name = file_util.get_project_streamlit_file_path("config.toml") + +# # save credentials to secrets +# if os.path.isfile(SECRETS_FILE_LOC): +# with open(SECRETS_FILE_LOC, "r", encoding="utf-8") as f: +# # use whitespace preserving parser +# secrets_ = tomlkit.load(f) +# else: +# secrets_ = tomlkit.document() + +# # save general settings +# secrets_["dlt"] = { +# "working_dir": pipeline.working_dir, +# "pipeline_name": pipeline.pipeline_name +# } + +# # get client config +# # TODO: pipeline api v2 should provide a direct method to get configurations +# CONFIG: BaseConfiguration = pipeline._loader_instance.load_client_cls.CONFIG # type: ignore +# CREDENTIALS: CredentialsConfiguration = pipeline._loader_instance.load_client_cls.CREDENTIALS # type: ignore + +# # save client config +# # print(dict_remove_nones_in_place(CONFIG.as_dict(lowercase=False))) +# dlt_c = cast(TomlContainer, secrets_["dlt"]) +# dlt_c["destination"] = dict_remove_nones_in_place(dict(CONFIG)) +# dlt_c["credentials"] = dict_remove_nones_in_place(dict(CREDENTIALS)) + +# with open(SECRETS_FILE_LOC, "w", encoding="utf-8") as f: +# # use whitespace preserving parser +# tomlkit.dump(secrets_, f) + + +# def write_data_explorer_page(pipeline: Pipeline, schema_name: str = None, show_dlt_tables: bool = False, example_query: str = "", show_charts: bool = True) -> None: +# """Writes Streamlit app page with a schema and live data preview. + +# Args: +# pipeline (Pipeline): Pipeline instance to use. +# schema_name (str, optional): Name of the schema to display. If None, default schema is used. +# show_dlt_tables (bool, optional): Should show DLT internal tables. Defaults to False. +# example_query (str, optional): Example query to be displayed in the SQL Query box. +# show_charts (bool, optional): Should automatically show charts for the queries from SQL Query box. Defaults to True. 
+ +# Raises: +# MissingDependencyException: Raised when a particular python dependency is not installed +# """ +# @st.experimental_memo(ttl=600) +# def run_query(query: str) -> pd.DataFrame: +# # dlt pipeline exposes configured sql client that (among others) let's you make queries against the warehouse +# with pipeline.sql_client(schema_name) as client: +# df = query_results_to_df(client, query) +# return df + +# if schema_name: +# schema = pipeline.get_schema(schema_name) +# else: +# schema = pipeline.get_default_schema() +# st.title(f"Available tables in {schema.name} schema") +# # st.text(schema.to_pretty_yaml()) + +# for table in schema.all_tables(with_dlt_tables=show_dlt_tables): +# table_name = table["name"] +# st.header(table_name) +# if "description" in table: +# st.text(table["description"]) +# if "parent" in table: +# st.text("Parent table: " + table["parent"]) + +# # table schema contains various hints (like clustering or partition options) that we do not want to show in basic view +# essentials_f = lambda c: {k:v for k, v in c.items() if k in ["name", "data_type", "nullable"]} + +# st.table(map(essentials_f, table["columns"].values())) +# # add a button that when pressed will show the full content of a table +# if st.button("SHOW DATA", key=table_name): +# st.text(f"Full {table_name} table content") +# st.dataframe(run_query(f"SELECT * FROM {table_name}")) + +# st.title("Run your query") +# sql_query = st.text_area("Enter your SQL query", value=example_query) +# if st.button("Run Query"): +# if sql_query: +# st.text("Results of a query") +# try: +# # run the query from the text area +# df = run_query(sql_query) +# # and display the results +# st.dataframe(df) + +# try: +# # now if the dataset has supported shape try to display the bar or altair chart +# if df.dtypes.shape[0] == 1 and show_charts: +# # try barchart +# st.bar_chart(df) +# if df.dtypes.shape[0] == 2 and show_charts: + +# # try to import altair charts +# try: +# import altair as alt +# except ImportError: +# raise MissingDependencyException( +# "DLT Streamlit Helpers", +# ["altair"], +# "DLT Helpers for Streamlit should be run within a streamlit app." 
+# ) + +# # try altair +# bar_chart = alt.Chart(df).mark_bar().encode( +# x=f'{df.columns[1]}:Q', +# y=alt.Y(f'{df.columns[0]}:N', sort='-x') +# ) +# st.altair_chart(bar_chart, use_container_width=True) +# except Exception as ex: +# st.error(f"Chart failed due to: {ex}") +# except Exception as ex: +# st.text("Exception when running query") +# st.exception(ex) diff --git a/dlt/load/__init__.py b/dlt/load/__init__.py index 28501cffe5..0a6c97ed3d 100644 --- a/dlt/load/__init__.py +++ b/dlt/load/__init__.py @@ -1,2 +1 @@ -from dlt._version import loader_version as __version__ from dlt.load.load import Load diff --git a/dlt/load/bigquery/__init__.py b/dlt/load/bigquery/__init__.py index e69de29bb2..b14aefc661 100644 --- a/dlt/load/bigquery/__init__.py +++ b/dlt/load/bigquery/__init__.py @@ -0,0 +1,39 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration + +from dlt.load.bigquery.configuration import BigQueryClientConfiguration + + +@with_config(spec=BigQueryClientConfiguration, namespaces=("destination", "bigquery",)) +def _configure(config: BigQueryClientConfiguration = ConfigValue) -> BigQueryClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": "jsonl", + "supported_loader_file_formats": ["jsonl"], + "max_identifier_length": 1024, + "max_column_length": 300, + "max_query_length": 1024 * 1024, + "is_max_query_length_in_bytes": False, + "max_text_data_type_length": 10 * 1024 * 1024, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.bigquery.bigquery import BigQueryClient + + return BigQueryClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return BigQueryClientConfiguration \ No newline at end of file diff --git a/dlt/load/bigquery/client.py b/dlt/load/bigquery/bigquery.py similarity index 83% rename from dlt/load/bigquery/client.py rename to dlt/load/bigquery/bigquery.py index 27a73efea4..db73a0b746 100644 --- a/dlt/load/bigquery/client.py +++ b/dlt/load/bigquery/bigquery.py @@ -1,6 +1,7 @@ from pathlib import Path from contextlib import contextmanager -from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple, Type +from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple +from dlt.common.storages.file_storage import FileStorage import google.cloud.bigquery as bigquery # noqa: I250 from google.cloud.bigquery.dbapi import Connection as DbApiConnection from google.cloud import exceptions as gcp_exceptions @@ -12,15 +13,18 @@ from dlt.common.typing import StrAny from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.configuration import GcpClientCredentials -from dlt.common.dataset_writers import escape_bigquery_identifier +from dlt.common.configuration.specs import GcpClientCredentials +from dlt.common.destination import DestinationCapabilitiesContext, 
TLoadJobStatus, LoadJob +from dlt.common.data_writers import escape_bigquery_identifier from dlt.common.schema import TColumnSchema, TDataType, Schema, TTableSchemaColumns -from dlt.load.typing import LoadJobStatus, DBCursor, TLoaderCapabilities -from dlt.load.client_base import JobClientBase, SqlClientBase, SqlJobClientBase, LoadJob +from dlt.load.typing import DBCursor +from dlt.load.sql_client import SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadJobNotExistsException, LoadJobServerTerminalException, LoadUnknownTableException -from dlt.load.bigquery.configuration import BigQueryClientConfiguration, configuration +from dlt.load.bigquery import capabilities +from dlt.load.bigquery.configuration import BigQueryClientConfiguration SCT_TO_BQT: Dict[TDataType, str] = { @@ -52,21 +56,21 @@ class BigQuerySqlClient(SqlClientBase[bigquery.Client]): - def __init__(self, default_dataset_name: str, CREDENTIALS: Type[GcpClientCredentials]) -> None: + def __init__(self, default_dataset_name: str, credentials: GcpClientCredentials) -> None: self._client: bigquery.Client = None - self.C = CREDENTIALS + self.credentials = credentials super().__init__(default_dataset_name) - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CREDENTIALS.RETRY_DEADLINE) + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline) self.default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name()) def open_connection(self) -> None: # use default credentials if partial config - if self.C.__is_partial__: + if not self.credentials.is_resolved(): credentials = None else: - credentials = service_account.Credentials.from_service_account_info(self.C.as_credentials()) - self._client = bigquery.Client(self.C.PROJECT_ID, credentials=credentials, location=self.C.LOCATION) + credentials = service_account.Credentials.from_service_account_info(self.credentials.to_native_representation()) + self._client = bigquery.Client(self.credentials.project_id, credentials=credentials, location=self.credentials.location) def close_connection(self) -> None: if self._client: @@ -79,7 +83,7 @@ def native_connection(self) -> bigquery.Client: def has_dataset(self) -> bool: try: - self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.C.HTTP_TIMEOUT) + self._client.get_dataset(self.fully_qualified_dataset_name(), retry=self.default_retry, timeout=self.credentials.http_timeout) return True except gcp_exceptions.NotFound: return False @@ -89,7 +93,7 @@ def create_dataset(self) -> None: self.fully_qualified_dataset_name(), exists_ok=False, retry=self.default_retry, - timeout=self.C.HTTP_TIMEOUT + timeout=self.credentials.http_timeout ) def drop_dataset(self) -> None: @@ -98,7 +102,7 @@ def drop_dataset(self) -> None: not_found_ok=True, delete_contents=True, retry=self.default_retry, - timeout=self.C.HTTP_TIMEOUT + timeout=self.credentials.http_timeout ) def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: @@ -106,7 +110,7 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen def_kwargs = { "job_config": self.default_query, "job_retry": self.default_retry, - "timeout": self.C.HTTP_TIMEOUT + "timeout": self.credentials.http_timeout } kwargs = {**def_kwargs, **(kwargs or {})} results = self._client.query(sql, *args, **kwargs).result() @@ -135,19 +139,19 @@ def execute_query(self, query: 
AnyStr, *args: Any, **kwargs: Any) -> Iterator[D conn.close() def fully_qualified_dataset_name(self) -> str: - return f"{self.C.PROJECT_ID}.{self.default_dataset_name}" + return f"{self.credentials.project_id}.{self.default_dataset_name}" class BigQueryLoadJob(LoadJob): - def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, CONFIG: Type[GcpClientCredentials]) -> None: + def __init__(self, file_name: str, bq_load_job: bigquery.LoadJob, credentials: GcpClientCredentials) -> None: self.bq_load_job = bq_load_job - self.C = CONFIG - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(CONFIG.RETRY_DEADLINE) + self.credentials = credentials + self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline) super().__init__(file_name) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # check server if done - done = self.bq_load_job.done(retry=self.default_retry, timeout=self.C.HTTP_TIMEOUT) + done = self.bq_load_job.done(retry=self.default_retry, timeout=self.credentials.http_timeout) if done: # rows processed if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: @@ -183,27 +187,30 @@ def exception(self) -> str: class BigQueryClient(SqlJobClientBase): - CONFIG: Type[BigQueryClientConfiguration] = None - CREDENTIALS: Type[GcpClientCredentials] = None + # CONFIG: BigQueryClientConfiguration = None + # CREDENTIALS: GcpClientCredentials = None - def __init__(self, schema: Schema) -> None: + def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: sql_client = BigQuerySqlClient( - schema.normalize_make_dataset_name(self.CONFIG.DEFAULT_DATASET, self.CONFIG.DEFAULT_SCHEMA_NAME, schema.name), - self.CREDENTIALS + schema.normalize_make_dataset_name(config.dataset_name, config.default_schema_name, schema.name), + config.credentials ) - super().__init__(schema, sql_client) + super().__init__(schema, config, sql_client) + self.config: BigQueryClientConfiguration = config self.sql_client: BigQuerySqlClient = sql_client - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: + if wipe_data: + raise NotImplementedError() if not self.sql_client.has_dataset(): self.sql_client.create_dataset() def restore_file_load(self, file_path: str) -> LoadJob: try: return BigQueryLoadJob( - JobClientBase.get_file_name_from_file_path(file_path), + FileStorage.get_file_name_from_file_path(file_path), self._retrieve_load_job(file_path), - self.CREDENTIALS + self.config.credentials #self.sql_client.native_connection() ) except api_core_exceptions.GoogleAPICallError as gace: @@ -218,9 +225,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: try: return BigQueryLoadJob( - JobClientBase.get_file_name_from_file_path(file_path), + FileStorage.get_file_name_from_file_path(file_path), self._create_load_job(table["name"], table["write_disposition"], file_path), - self.CREDENTIALS + self.config.credentials ) except api_core_exceptions.GoogleAPICallError as gace: reason = self._get_reason_from_errors(gace) @@ -296,7 +303,9 @@ def _get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns schema_table: TTableSchemaColumns = {} try: table = self.sql_client.native_connection.get_table( - self.sql_client.make_qualified_table_name(table_name), retry=self.sql_client.default_retry, timeout=self.CREDENTIALS.HTTP_TIMEOUT + self.sql_client.make_qualified_table_name(table_name), + 
retry=self.sql_client.default_retry, + timeout=self.config.credentials.http_timeout ) partition_field = table.time_partitioning.field if table.time_partitioning else None for c in table.schema: @@ -329,12 +338,13 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition ) with open(file_path, "rb") as f: - return self.sql_client.native_connection.load_table_from_file(f, - self.sql_client.make_qualified_table_name(table_name), - job_id=job_id, - job_config=job_config, - timeout=self.CREDENTIALS.HTTP_TIMEOUT - ) + return self.sql_client.native_connection.load_table_from_file( + f, + self.sql_client.make_qualified_table_name(table_name), + job_id=job_id, + job_config=job_config, + timeout=self.config.credentials.file_upload_timeout + ) def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: job_id = BigQueryClient._get_job_id_from_file_path(file_path) @@ -367,18 +377,5 @@ def _bq_t_to_sc_t(bq_t: str, precision: Optional[int], scale: Optional[int]) -> return BQT_TO_SCT.get(bq_t, "text") @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": "jsonl", - "supported_loader_file_formats": ["jsonl"], - "max_identifier_length": 1024, - "max_column_length": 300 - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[BigQueryClientConfiguration], Type[GcpClientCredentials]]: - cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) - return cls.CONFIG, cls.CREDENTIALS - - -CLIENT = BigQueryClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/bigquery/configuration.py b/dlt/load/bigquery/configuration.py index 8bca223db3..496d9b0f05 100644 --- a/dlt/load/bigquery/configuration.py +++ b/dlt/load/bigquery/configuration.py @@ -1,37 +1,25 @@ -from typing import Tuple, Type +from typing import Optional from google.auth import default as default_credentials from google.auth.exceptions import DefaultCredentialsError -from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, GcpClientCredentials -from dlt.common.configuration.exceptions import ConfigEntryMissingException +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import GcpClientCredentials +from dlt.common.configuration.exceptions import ConfigFieldMissingException +from dlt.common.destination import DestinationClientDwhConfiguration -from dlt.load.configuration import LoaderClientDwhConfiguration +@configspec(init=True) +class BigQueryClientConfiguration(DestinationClientDwhConfiguration): + destination_name: str = "bigquery" + credentials: Optional[GcpClientCredentials] = None -class BigQueryClientConfiguration(LoaderClientDwhConfiguration): - CLIENT_TYPE: str = "bigquery" - - -def configuration(initial_values: StrAny = None) -> Tuple[Type[BigQueryClientConfiguration], Type[GcpClientCredentials]]: - - def maybe_partial_credentials() -> Type[GcpClientCredentials]: - try: - return make_configuration(GcpClientCredentials, GcpClientCredentials, initial_values=initial_values) - except ConfigEntryMissingException as cfex: - # if config is missing check if credentials can be obtained from defaults + def check_integrity(self) -> None: + if not self.credentials.is_resolved(): + # if config is missing check if credentials can be obtained from defaults try: _, project_id = default_credentials() - # if so then return partial so we can access timeouts - C_PARTIAL = 
make_configuration(GcpClientCredentials, GcpClientCredentials, initial_values=initial_values, accept_partial = True) # set the project id - it needs to be known by the client - C_PARTIAL.PROJECT_ID = C_PARTIAL.PROJECT_ID or project_id - return C_PARTIAL + self.credentials.project_id = self.credentials.project_id or project_id except DefaultCredentialsError: - raise cfex - - return ( - make_configuration(BigQueryClientConfiguration, BigQueryClientConfiguration, initial_values=initial_values), - # allow partial credentials so the client can fallback to default credentials - maybe_partial_credentials() - ) + # re-raise preventing exception + raise self.credentials.__exception__ diff --git a/dlt/load/client_base.py b/dlt/load/client_base.py deleted file mode 100644 index d58671ae27..0000000000 --- a/dlt/load/client_base.py +++ /dev/null @@ -1,227 +0,0 @@ -from abc import ABC, abstractmethod -from contextlib import contextmanager -from types import TracebackType -from typing import Any, ContextManager, Generic, Iterator, List, Optional, Sequence, Tuple, Type, AnyStr -from pathlib import Path - -from dlt.common import pendulum, logger -from dlt.common.configuration import BaseConfiguration, CredentialsConfiguration -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema -from dlt.common.typing import StrAny - -from dlt.load.typing import LoadJobStatus, TNativeConn, TLoaderCapabilities, DBCursor -from dlt.load.exceptions import LoadClientSchemaVersionCorrupted - - -class LoadJob: - """Represents a job that loads a single file - - Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". - Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. - In "running" state, the loader component periodically gets the state via `status()` method. When terminal state is reached, load job is discarded and not called again. - `exception` method is called to get error information in "failed" and "retry" states. - - The `__init__` method is responsible to put the Job in "running" state. It may raise `LoadClientTerminalException` and `LoadClientTransientException` tp - immediately transition job into "failed" or "retry" state respectively. 
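# The job lifecycle described in the docstring above (start in "running", end in one of the
# terminal states "retry", "failed" or "completed", with the file name doubling as the job id)
# still applies to the LoadJob imported from dlt.common.destination elsewhere in this diff,
# e.g. BigQueryLoadJob. A tiny stand-in sketch of that state machine; names are illustrative,
# not dlt API.
from typing import Literal

TJobStatus = Literal["running", "retry", "failed", "completed"]
TERMINAL_STATES = ("retry", "failed", "completed")

class ToyLoadJob:
    def __init__(self, file_name: str) -> None:
        # the file name is the unique job id; the file is guaranteed to exist while running
        self.file_name = file_name
        self._status: TJobStatus = "running"

    def complete(self) -> None:
        self._status = "completed"

    def status(self) -> TJobStatus:
        return self._status

job = ToyLoadJob("example_table.12345.jsonl")
job.complete()
assert job.status() in TERMINAL_STATES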
- """ - def __init__(self, file_name: str) -> None: - """ - File name is also a job id (or job id is deterministically derived) so it must be globally unique - """ - self._file_name = file_name - - @abstractmethod - def status(self) -> LoadJobStatus: - pass - - @abstractmethod - def file_name(self) -> str: - pass - - @abstractmethod - def exception(self) -> str: - pass - - -class LoadEmptyJob(LoadJob): - def __init__(self, file_name: str, status: LoadJobStatus, exception: str = None) -> None: - self._status = status - self._exception = exception - super().__init__(file_name) - - def status(self) -> LoadJobStatus: - return self._status - - def file_name(self) -> str: - return self._file_name - - def exception(self) -> str: - return self._exception - - -class JobClientBase(ABC): - def __init__(self, schema: Schema) -> None: - self.schema = schema - - @abstractmethod - def initialize_storage(self) -> None: - pass - - @abstractmethod - def update_storage_schema(self) -> None: - pass - - @abstractmethod - def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: - pass - - @abstractmethod - def restore_file_load(self, file_path: str) -> LoadJob: - pass - - @abstractmethod - def complete_load(self, load_id: str) -> None: - pass - - @abstractmethod - def __enter__(self) -> "JobClientBase": - pass - - @abstractmethod - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: - pass - - @staticmethod - def get_file_name_from_file_path(file_path: str) -> str: - return Path(file_path).name - - @staticmethod - def make_job_with_status(file_path: str, status: LoadJobStatus, message: str = None) -> LoadJob: - return LoadEmptyJob(JobClientBase.get_file_name_from_file_path(file_path), status, exception=message) - - @staticmethod - def make_absolute_path(file_path: str) -> str: - return str(Path(file_path).absolute()) - - @classmethod - @abstractmethod - def capabilities(cls) -> TLoaderCapabilities: - pass - - @classmethod - @abstractmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[BaseConfiguration], Type[CredentialsConfiguration]]: - pass - - -class SqlClientBase(ABC, Generic[TNativeConn]): - def __init__(self, default_dataset_name: str) -> None: - if not default_dataset_name: - raise ValueError(default_dataset_name) - self.default_dataset_name = default_dataset_name - - @abstractmethod - def open_connection(self) -> None: - pass - - @abstractmethod - def close_connection(self) -> None: - pass - - def __enter__(self) -> "SqlClientBase[TNativeConn]": - self.open_connection() - return self - - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: - self.close_connection() - - @abstractmethod - def native_connection(self) -> TNativeConn: - pass - - @abstractmethod - def has_dataset(self) -> bool: - pass - - @abstractmethod - def create_dataset(self) -> None: - pass - - @abstractmethod - def drop_dataset(self) -> None: - pass - - @abstractmethod - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: - pass - - @abstractmethod - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> ContextManager[DBCursor]: - pass - - @abstractmethod - def fully_qualified_dataset_name(self) -> str: - pass - - def make_qualified_table_name(self, table_name: str) -> str: - return f"{self.fully_qualified_dataset_name()}.{table_name}" - - @contextmanager - def with_alternative_dataset_name(self, dataset_name: str) -> 
Iterator["SqlClientBase[TNativeConn]"]: - current_dataset_name = self.default_dataset_name - try: - self.default_dataset_name = dataset_name - yield self - finally: - # restore previous dataset name - self.default_dataset_name = current_dataset_name - - -class SqlJobClientBase(JobClientBase): - def __init__(self, schema: Schema, sql_client: SqlClientBase[TNativeConn]) -> None: - super().__init__(schema) - self.sql_client = sql_client - - def update_storage_schema(self) -> None: - storage_version = self._get_schema_version_from_storage() - if storage_version < self.schema.stored_version: - for sql in self._build_schema_update_sql(): - self.sql_client.execute_sql(sql) - self._update_schema_version(self.schema.stored_version) - - def complete_load(self, load_id: str) -> None: - name = self.sql_client.make_qualified_table_name(Schema.LOADS_TABLE_NAME) - now_ts = str(pendulum.now()) - self.sql_client.execute_sql(f"INSERT INTO {name}(load_id, status, inserted_at) VALUES('{load_id}', 0, '{now_ts}');") - - def __enter__(self) -> "SqlJobClientBase": - self.sql_client.open_connection() - return self - - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: - self.sql_client.close_connection() - - @abstractmethod - def _build_schema_update_sql(self) -> List[str]: - pass - - def _create_table_update(self, table_name: str, storage_table: TTableSchemaColumns) -> Sequence[TColumnSchema]: - # compare table with stored schema and produce delta - updates = self.schema.get_schema_update_for(table_name, storage_table) - logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") - return updates - - def _get_schema_version_from_storage(self) -> int: - name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) - rows = self.sql_client.execute_sql(f"SELECT {Schema.VERSION_COLUMN_NAME} FROM {name} ORDER BY inserted_at DESC LIMIT 1;") - if len(rows) > 1: - raise LoadClientSchemaVersionCorrupted(self.sql_client.fully_qualified_dataset_name()) - if len(rows) == 0: - return 0 - return int(rows[0][0]) - - def _update_schema_version(self, new_version: int) -> None: - now_ts = str(pendulum.now()) - name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) - self.sql_client.execute_sql(f"INSERT INTO {name}({Schema.VERSION_COLUMN_NAME}, engine_version, inserted_at) VALUES ({new_version}, {Schema.ENGINE_VERSION}, '{now_ts}');") diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index bd5da8935c..42d8cb1209 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -1,32 +1,25 @@ -from typing import Any, Optional, Type -from dlt.common.configuration.run_configuration import BaseConfiguration - -from dlt.common.typing import StrAny -from dlt.common.configuration import (PoolRunnerConfiguration, - LoadVolumeConfiguration, - ProductionLoadVolumeConfiguration, - TPoolType, make_configuration) -from . 
import __version__ - - -class LoaderClientConfiguration(BaseConfiguration): - CLIENT_TYPE: str = None # which destination to load data to - - -class LoaderClientDwhConfiguration(LoaderClientConfiguration): - DEFAULT_DATASET: str = None # dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix - DEFAULT_SCHEMA_NAME: Optional[str] = None # name of default schema to be used to name effective dataset to load data to - - -class LoaderConfiguration(PoolRunnerConfiguration, LoadVolumeConfiguration, LoaderClientConfiguration): - WORKERS: int = 20 # how many parallel loads can be executed - # MAX_PARALLELISM: int = 20 # in 20 separate threads - POOL_TYPE: TPoolType = "thread" # mostly i/o (upload) so may be thread pool - - -class ProductionLoaderConfiguration(ProductionLoadVolumeConfiguration, LoaderConfiguration): - pass - - -def configuration(initial_values: StrAny = None) -> Type[LoaderConfiguration]: - return make_configuration(LoaderConfiguration, ProductionLoaderConfiguration, initial_values=initial_values) +from typing import TYPE_CHECKING + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import BaseConfiguration, PoolRunnerConfiguration, CredentialsConfiguration, TPoolType +from dlt.common.configuration.specs.load_volume_configuration import LoadVolumeConfiguration + + +@configspec(init=True) +class LoaderConfiguration(PoolRunnerConfiguration): + workers: int = 20 # how many parallel loads can be executed + pool_type: TPoolType = "thread" # mostly i/o (upload) so may be thread pool + always_wipe_storage: bool = False # removes all data in the storage + _load_storage_config: LoadVolumeConfiguration = None + + if TYPE_CHECKING: + def __init__( + self, + pool_type: TPoolType = None, + workers: int = None, + exit_on_exception: bool = None, + is_single_run: bool = None, + always_wipe_storage: bool = None, + load_storage_config: LoadVolumeConfiguration = None + ) -> None: + ... 
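A hedged usage sketch of the new `LoaderConfiguration` spec above: it assumes `@configspec(init=True)` synthesizes the keyword constructor mirrored in the `TYPE_CHECKING` block, and that fields left unset are resolved by the configuration providers when the spec is injected through `@with_config(spec=LoaderConfiguration, ...)`. The concrete values are illustrative only, not defaults introduced by this PR.

    from dlt.load.configuration import LoaderConfiguration

    # explicit values; anything left as None is expected to be filled in by
    # config providers (env vars, config files) at injection time
    config = LoaderConfiguration(workers=4, pool_type="thread", is_single_run=True)
    assert config.workers == 4
    assert config.always_wipe_storage is False  # spec default declared in the hunk above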
diff --git a/dlt/load/dummy/__init__.py b/dlt/load/dummy/__init__.py index e69de29bb2..b29ba69807 100644 --- a/dlt/load/dummy/__init__.py +++ b/dlt/load/dummy/__init__.py @@ -0,0 +1,40 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration + +from dlt.load.dummy.configuration import DummyClientConfiguration + + +@with_config(spec=DummyClientConfiguration, namespaces=("destination", "dummy",)) +def _configure(config: DummyClientConfiguration = ConfigValue) -> DummyClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + config = _configure() + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": config.loader_file_format, + "supported_loader_file_formats": [config.loader_file_format], + "max_identifier_length": 127, + "max_column_length": 127, + "max_query_length": 8 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.dummy.dummy import DummyClient + + return DummyClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return DummyClientConfiguration diff --git a/dlt/load/dummy/configuration.py b/dlt/load/dummy/configuration.py index 79c414fd50..f39180e1e4 100644 --- a/dlt/load/dummy/configuration.py +++ b/dlt/load/dummy/configuration.py @@ -1,20 +1,12 @@ -from typing import Type - -from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration -from dlt.common.dataset_writers import TLoaderFileFormat - -from dlt.load.configuration import LoaderClientConfiguration - - -class DummyClientConfiguration(LoaderClientConfiguration): - CLIENT_TYPE: str = "dummy" - LOADER_FILE_FORMAT: TLoaderFileFormat = "jsonl" - FAIL_PROB: float = 0.0 - RETRY_PROB: float = 0.0 - COMPLETED_PROB: float = 0.0 - TIMEOUT: float = 10.0 - - -def configuration(initial_values: StrAny = None) -> Type[DummyClientConfiguration]: - return make_configuration(DummyClientConfiguration, DummyClientConfiguration, initial_values=initial_values) +from dlt.common.configuration import configspec +from dlt.common.destination import DestinationClientConfiguration, TLoaderFileFormat + + +@configspec(init=True) +class DummyClientConfiguration(DestinationClientConfiguration): + destination_name: str = "dummy" + loader_file_format: TLoaderFileFormat = "jsonl" + fail_prob: float = 0.0 + retry_prob: float = 0.0 + completed_prob: float = 0.0 + timeout: float = 10.0 diff --git a/dlt/load/dummy/client.py b/dlt/load/dummy/dummy.py similarity index 69% rename from dlt/load/dummy/client.py rename to dlt/load/dummy/dummy.py index 7e356e7a40..52d610b7a8 100644 --- a/dlt/load/dummy/client.py +++ b/dlt/load/dummy/dummy.py @@ -1,20 +1,18 @@ import random from types import TracebackType -from typing import Dict, Tuple, Type -from dlt.common.dataset_writers import TLoaderFileFormat +from typing import Dict, Type from dlt.common import pendulum from dlt.common.schema import Schema +from dlt.common.storages import 
FileStorage from dlt.common.schema.typing import TTableSchema -from dlt.common.configuration import CredentialsConfiguration -from dlt.common.typing import StrAny +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, LoadJob, TLoadJobStatus -from dlt.load.client_base import JobClientBase, LoadJob, TLoaderCapabilities -from dlt.load.typing import LoadJobStatus from dlt.load.exceptions import (LoadJobNotExistsException, LoadJobInvalidStateTransitionException, LoadClientTerminalException, LoadClientTransientException) -from dlt.load.dummy.configuration import DummyClientConfiguration, configuration +from dlt.load.dummy import capabilities +from dlt.load.dummy.configuration import DummyClientConfiguration class LoadDummyJob(LoadJob): @@ -23,7 +21,7 @@ def __init__(self, file_name: str, fail_prob: float = 0.0, retry_prob: float = 0 self.retry_prob = retry_prob self.completed_prob = completed_prob self.timeout = timeout - self._status: LoadJobStatus = "running" + self._status: TLoadJobStatus = "running" self._exception: str = None self.start_time: float = pendulum.now().timestamp() super().__init__(file_name) @@ -34,7 +32,7 @@ def __init__(self, file_name: str, fail_prob: float = 0.0, retry_prob: float = 0 raise LoadClientTransientException(self._exception) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # this should poll the server for a job status, here we simulate various outcomes if self._status == "running": n = pendulum.now().timestamp() @@ -78,20 +76,20 @@ class DummyClient(JobClientBase): """ dummy client storing jobs in memory """ - CONFIG: Type[DummyClientConfiguration] = None - def __init__(self, schema: Schema) -> None: - super().__init__(schema) + def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: + super().__init__(schema, config) + self.config: DummyClientConfiguration = config - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: pass def update_storage_schema(self) -> None: pass def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: - job_id = JobClientBase.get_file_name_from_file_path(file_path) - file_name = JobClientBase.get_file_name_from_file_path(file_path) + job_id = FileStorage.get_file_name_from_file_path(file_path) + file_name = FileStorage.get_file_name_from_file_path(file_path) # return existing job if already there if job_id not in JOBS: JOBS[job_id] = self._create_job(file_name) @@ -103,7 +101,7 @@ def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: return JOBS[job_id] def restore_file_load(self, file_path: str) -> LoadJob: - job_id = JobClientBase.get_file_name_from_file_path(file_path) + job_id = FileStorage.get_file_name_from_file_path(file_path) if job_id not in JOBS: raise LoadJobNotExistsException(job_id) return JOBS[job_id] @@ -120,25 +118,12 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _create_job(self, job_id: str) -> LoadDummyJob: return LoadDummyJob( job_id, - fail_prob=self.CONFIG.FAIL_PROB, - retry_prob=self.CONFIG.RETRY_PROB, - completed_prob=self.CONFIG.COMPLETED_PROB, - timeout=self.CONFIG.TIMEOUT + fail_prob=self.config.fail_prob, + retry_prob=self.config.retry_prob, + completed_prob=self.config.completed_prob, + timeout=self.config.timeout ) @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": cls.CONFIG.LOADER_FILE_FORMAT, - "supported_loader_file_formats": 
[cls.CONFIG.LOADER_FILE_FORMAT], - "max_identifier_length": 127, - "max_column_length": 127 - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[DummyClientConfiguration], Type[CredentialsConfiguration]]: - cls.CONFIG = configuration(initial_values=initial_values) - return cls.CONFIG, None - - -CLIENT = DummyClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index 6f7bb8a1d0..60f4b8008d 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -1,7 +1,6 @@ from typing import Sequence from dlt.common.exceptions import DltException, TerminalException, TransientException - -from dlt.load.typing import LoadJobStatus +from dlt.common.destination import TLoadJobStatus class LoadException(DltException): @@ -11,12 +10,12 @@ def __init__(self, msg: str) -> None: class LoadClientTerminalException(LoadException, TerminalException): def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class LoadClientTransientException(LoadException, TransientException): def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class LoadClientTerminalInnerException(LoadClientTerminalException): @@ -44,7 +43,7 @@ def __init__(self, table_name: str, file_name: str) -> None: class LoadJobInvalidStateTransitionException(LoadClientTerminalException): - def __init__(self, from_state: LoadJobStatus, to_state: LoadJobStatus) -> None: + def __init__(self, from_state: TLoadJobStatus, to_state: TLoadJobStatus) -> None: self.from_state = from_state self.to_state = to_state super().__init__(f"Load job cannot transition form {from_state} to {to_state}") diff --git a/dlt/load/job_client_impl.py b/dlt/load/job_client_impl.py new file mode 100644 index 0000000000..5085c1f368 --- /dev/null +++ b/dlt/load/job_client_impl.py @@ -0,0 +1,81 @@ +from abc import abstractmethod +from types import TracebackType +from typing import List, Sequence, Type + +from dlt.common import pendulum, logger +from dlt.common.storages import FileStorage +from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.destination import DestinationClientConfiguration, TLoadJobStatus, LoadJob, JobClientBase + +from dlt.load.typing import TNativeConn +from dlt.load.sql_client import SqlClientBase +from dlt.load.exceptions import LoadClientSchemaVersionCorrupted + + +class LoadEmptyJob(LoadJob): + def __init__(self, file_name: str, status: TLoadJobStatus, exception: str = None) -> None: + self._status = status + self._exception = exception + super().__init__(file_name) + + @classmethod + def from_file_path(cls, file_path: str, status: TLoadJobStatus, message: str = None) -> "LoadEmptyJob": + return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) + + def status(self) -> TLoadJobStatus: + return self._status + + def file_name(self) -> str: + return self._file_name + + def exception(self) -> str: + return self._exception + + +class SqlJobClientBase(JobClientBase): + def __init__(self, schema: Schema, config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn]) -> None: + super().__init__(schema, config) + self.sql_client = sql_client + + def update_storage_schema(self) -> None: + storage_version = self._get_schema_version_from_storage() + if storage_version < self.schema.stored_version: + for sql in self._build_schema_update_sql(): + self.sql_client.execute_sql(sql) + 
self._update_schema_version(self.schema.stored_version) + + def complete_load(self, load_id: str) -> None: + name = self.sql_client.make_qualified_table_name(Schema.LOADS_TABLE_NAME) + now_ts = str(pendulum.now()) + self.sql_client.execute_sql(f"INSERT INTO {name}(load_id, status, inserted_at) VALUES('{load_id}', 0, '{now_ts}');") + + def __enter__(self) -> "SqlJobClientBase": + self.sql_client.open_connection() + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + self.sql_client.close_connection() + + @abstractmethod + def _build_schema_update_sql(self) -> List[str]: + pass + + def _create_table_update(self, table_name: str, storage_table: TTableSchemaColumns) -> Sequence[TColumnSchema]: + # compare table with stored schema and produce delta + updates = self.schema.get_new_columns(table_name, storage_table) + logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") + return updates + + def _get_schema_version_from_storage(self) -> int: + name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) + rows = self.sql_client.execute_sql(f"SELECT {Schema.VERSION_COLUMN_NAME} FROM {name} ORDER BY inserted_at DESC LIMIT 1;") + if len(rows) > 1: + raise LoadClientSchemaVersionCorrupted(self.sql_client.fully_qualified_dataset_name()) + if len(rows) == 0: + return 0 + return int(rows[0][0]) + + def _update_schema_version(self, new_version: int) -> None: + now_ts = str(pendulum.now()) + name = self.sql_client.make_qualified_table_name(Schema.VERSION_TABLE_NAME) + self.sql_client.execute_sql(f"INSERT INTO {name}({Schema.VERSION_COLUMN_NAME}, engine_version, inserted_at) VALUES ({new_version}, {Schema.ENGINE_VERSION}, '{now_ts}');") diff --git a/dlt/load/load.py b/dlt/load/load.py index 7b5516c9cf..dbad1ce523 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,27 +1,22 @@ -from typing import List, Optional, Tuple, Type, Protocol +from typing import List, Optional, Tuple from multiprocessing.pool import ThreadPool -from importlib import import_module from prometheus_client import REGISTRY, Counter, Gauge, CollectorRegistry, Summary from dlt.common import sleep, logger -from dlt.cli import TRunnerArgs -from dlt.common.runners import TRunMetrics, initialize_runner, run_pool, Runnable, workermethod +from dlt.common.configuration import with_config +from dlt.common.typing import ConfigValue +from dlt.common.runners import TRunMetrics, Runnable, workermethod from dlt.common.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchema -from dlt.common.storages.load_storage import LoadStorage +from dlt.common.storages import LoadStorage from dlt.common.telemetry import get_logging_extras, set_gauge_all_labels -from dlt.common.typing import StrAny +from dlt.common.destination import JobClientBase, DestinationReference, LoadJob, TLoadJobStatus, DestinationClientConfiguration +from dlt.load.job_client_impl import LoadEmptyJob +from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import LoadClientTerminalException, LoadClientTransientException, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, LoadJobNotExistsException, LoadUnknownTableException -from dlt.load.client_base import JobClientBase, LoadJob -from dlt.load.typing import LoadJobStatus, TLoaderCapabilities -from dlt.load.configuration import configuration, 
LoaderConfiguration - - -class SupportsLoadClient(Protocol): - CLIENT: Type[JobClientBase] class Load(Runnable[ThreadPool]): @@ -31,35 +26,34 @@ class Load(Runnable[ThreadPool]): job_counter: Counter = None job_wait_summary: Summary = None - def __init__(self, C: Type[LoaderConfiguration], collector: CollectorRegistry, client_initial_values: StrAny = None, is_storage_owner: bool = False) -> None: - self.CONFIG = C - self.load_client_cls = self.import_client_cls(C.CLIENT_TYPE, initial_values=client_initial_values) + @with_config(spec=LoaderConfiguration, namespaces=("load",)) + def __init__( + self, + destination: DestinationReference, + collector: CollectorRegistry = REGISTRY, + is_storage_owner: bool = False, + config: LoaderConfiguration = ConfigValue, + initial_client_config: DestinationClientConfiguration = ConfigValue + ) -> None: + self.config = config + self.initial_client_config = initial_client_config + self.destination = destination + self.capabilities = destination.capabilities() self.pool: ThreadPool = None self.load_storage: LoadStorage = self.create_storage(is_storage_owner) try: Load.create_gauges(collector) except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated" not in str(v): raise - @staticmethod - def loader_capabilities(client_type: str) -> TLoaderCapabilities: - m: SupportsLoadClient = import_module(f"dlt.load.{client_type}.client") - return m.CLIENT.capabilities() - - @staticmethod - def import_client_cls(client_type: str, initial_values: StrAny = None) -> Type[JobClientBase]: - m: SupportsLoadClient = import_module(f"dlt.load.{client_type}.client") - m.CLIENT.configure(initial_values) - return m.CLIENT - def create_storage(self, is_storage_owner: bool) -> LoadStorage: load_storage = LoadStorage( is_storage_owner, - self.CONFIG, - self.load_client_cls.capabilities()["preferred_loader_file_format"], - self.load_client_cls.capabilities()["supported_loader_file_formats"] + self.capabilities.preferred_loader_file_format, + self.capabilities.supported_loader_file_formats, + config=self.config._load_storage_config ) return load_storage @@ -87,19 +81,19 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O # open new connection for each upload job: LoadJob = None try: - with self.load_client_cls(schema) as client: + with self.destination.client(schema, self.initial_client_config) as client: job_info = self.load_storage.parse_job_file_name(file_path) - if job_info.file_format not in client.capabilities()["supported_loader_file_formats"]: - raise LoadClientUnsupportedFileFormats(job_info.file_format, client.capabilities()["supported_loader_file_formats"], file_path) + if job_info.file_format not in self.capabilities.supported_loader_file_formats: + raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = self.get_load_table(schema, job_info.table_name, file_path) if table["write_disposition"] not in ["append", "replace"]: raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path) - job = client.start_file_load(table, self.load_storage.storage._make_path(file_path)) + job = client.start_file_load(table, self.load_storage.storage.make_full_path(file_path)) except (LoadClientTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed 
logger.exception(f"Terminal problem with spooling job {file_path}") - job = JobClientBase.make_job_with_status(file_path, "failed", pretty_format_exception()) + job = LoadEmptyJob.from_file_path(file_path, "failed", pretty_format_exception()) except (LoadClientTransientException, Exception): # return no job so file stays in new jobs (root) folder logger.exception(f"Temporary problem with spooling job {file_path}") @@ -113,7 +107,7 @@ def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJo # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs # TODO: combine files by providing a list of files pertaining to same table into job, so job must be # extended to accept a list - load_files = self.load_storage.list_new_jobs(load_id)[:self.CONFIG.WORKERS] + load_files = self.load_storage.list_new_jobs(load_id)[:self.config.workers] file_count = len(load_files) if file_count == 0: logger.info(f"No new jobs found in {load_id}") @@ -141,7 +135,7 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str) -> Tuple[int, List[ job = client.restore_file_load(file_path) except LoadClientTerminalException: logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") - job = JobClientBase.make_job_with_status(file_path, "failed", pretty_format_exception()) + job = LoadEmptyJob.from_file_path(file_path, "failed", pretty_format_exception()) # proceed to appending job, do not reraise except (LoadClientTransientException, Exception): # raise on all temporary exceptions, typically network / server problems @@ -161,7 +155,7 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob]) -> List[LoadJob]: for ii in range(len(jobs)): job = jobs[ii] logger.debug(f"Checking status for job {job.file_name()}") - status: LoadJobStatus = job.status() + status: TLoadJobStatus = job.status() final_location: str = None if status == "running": # ask again @@ -207,12 +201,12 @@ def run(self, pool: ThreadPool) -> TRunMetrics: schema = self.load_storage.load_package_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") # initialize analytical storage ie. 
create dataset required by passed schema - with self.load_client_cls(schema) as client: - logger.info(f"Client {self.CONFIG.CLIENT_TYPE} will start load") + with self.destination.client(schema, self.initial_client_config) as client: + logger.info(f"Client for {client.config.destination_name} will start load") client.initialize_storage() schema_update = self.load_storage.begin_schema_update(load_id) if schema_update: - logger.info(f"Client {self.CONFIG.CLIENT_TYPE} will update schema to package schema") + logger.info(f"Client for {client.config.destination_name} will update schema to package schema") # TODO: this should rather generate an SQL job(s) to be executed PRE loading client.update_storage_schema() self.load_storage.commit_schema_update(load_id) @@ -231,7 +225,7 @@ def run(self, pool: ThreadPool) -> TRunMetrics: ) # if there are no existing or new jobs we complete the package if jobs_count == 0: - with self.load_client_cls(schema) as client: + with self.destination.client(schema, self.initial_client_config) as client: # TODO: this script should be executed as a job (and contain also code to merge/upsert data and drop temp tables) # TODO: post loading jobs remaining_jobs = client.complete_load(load_id) @@ -240,6 +234,7 @@ def run(self, pool: ThreadPool) -> TRunMetrics: self.load_counter.inc() logger.metrics("Load package metrics", extra=get_logging_extras([self.load_counter])) else: + # TODO: this loop must be urgently removed. while True: remaining_jobs = self.complete_jobs(load_id, jobs) if len(remaining_jobs) == 0: @@ -250,18 +245,3 @@ def run(self, pool: ThreadPool) -> TRunMetrics: sleep(1) return TRunMetrics(False, False, len(self.load_storage.list_packages())) - - -def main(args: TRunnerArgs) -> int: - C = configuration(args._asdict()) - initialize_runner(C) - try: - load = Load(C, REGISTRY) - except Exception: - logger.exception("init module") - return -1 - return run_pool(C, load) - - -def run_main(args: TRunnerArgs) -> None: - exit(main(args)) diff --git a/dlt/load/redshift/__init__.py b/dlt/load/redshift/__init__.py index e69de29bb2..45723d5540 100644 --- a/dlt/load/redshift/__init__.py +++ b/dlt/load/redshift/__init__.py @@ -0,0 +1,39 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.typing import ConfigValue +from dlt.common.configuration import with_config +from dlt.common.destination import DestinationCapabilitiesContext, JobClientBase, DestinationClientConfiguration + +from dlt.load.redshift.configuration import RedshiftClientConfiguration + + +@with_config(spec=RedshiftClientConfiguration, namespaces=("destination", "redshift",)) +def _configure(config: RedshiftClientConfiguration = ConfigValue) -> RedshiftClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.update({ + "preferred_loader_file_format": "insert_values", + "supported_loader_file_formats": ["insert_values"], + "max_identifier_length": 127, + "max_column_length": 127, + "max_query_length": 16 * 1024 * 1024, + "is_max_query_length_in_bytes": True, + "max_text_data_type_length": 65535, + "is_max_text_data_type_length_in_bytes": True + }) + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = ConfigValue) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.load.redshift.redshift import RedshiftClient + + return RedshiftClient(schema, 
_configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return RedshiftClientConfiguration diff --git a/dlt/load/redshift/configuration.py b/dlt/load/redshift/configuration.py index cd883ed885..ce724eec4a 100644 --- a/dlt/load/redshift/configuration.py +++ b/dlt/load/redshift/configuration.py @@ -1,17 +1,10 @@ -from typing import Tuple, Type +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import PostgresCredentials +from dlt.common.destination import DestinationClientDwhConfiguration -from dlt.common.typing import StrAny -from dlt.common.configuration import make_configuration, PostgresCredentials -from dlt.load.configuration import LoaderClientDwhConfiguration +@configspec(init=True) +class RedshiftClientConfiguration(DestinationClientDwhConfiguration): + destination_name: str = "redshift" + credentials: PostgresCredentials - -class RedshiftClientConfiguration(LoaderClientDwhConfiguration): - CLIENT_TYPE: str = "redshift" - - -def configuration(initial_values: StrAny = None) -> Tuple[Type[RedshiftClientConfiguration], Type[PostgresCredentials]]: - return ( - make_configuration(RedshiftClientConfiguration, RedshiftClientConfiguration, initial_values=initial_values), - make_configuration(PostgresCredentials, PostgresCredentials, initial_values=initial_values) - ) diff --git a/dlt/load/redshift/client.py b/dlt/load/redshift/redshift.py similarity index 85% rename from dlt/load/redshift/client.py rename to dlt/load/redshift/redshift.py index a7dba902d4..7764460e35 100644 --- a/dlt/load/redshift/client.py +++ b/dlt/load/redshift/redshift.py @@ -1,5 +1,4 @@ import platform - if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 from psycopg2cffi.sql import SQL, Identifier, Composed, Literal as SQLLiteral @@ -9,21 +8,23 @@ from contextlib import contextmanager -from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple, Type -from dlt.common.configuration.postgres_credentials import PostgresCredentials +from typing import Any, AnyStr, Dict, Iterator, List, Optional, Sequence, Tuple -from dlt.common.typing import StrAny from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.common.dataset_writers import escape_redshift_identifier -from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, THintType, Schema, TTableSchemaColumns, add_missing_hints +from dlt.common.configuration.specs import PostgresCredentials +from dlt.common.destination import DestinationCapabilitiesContext, LoadJob, TLoadJobStatus +from dlt.common.data_writers import escape_redshift_identifier +from dlt.common.schema import COLUMN_HINTS, TColumnSchema, TColumnSchemaBase, TDataType, TColumnHint, Schema, TTableSchemaColumns, add_missing_hints from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.storages.file_storage import FileStorage -from dlt.load.exceptions import (LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, - LoadClientTransientInnerException) -from dlt.load.typing import LoadJobStatus, DBCursor, TLoaderCapabilities -from dlt.load.client_base import JobClientBase, SqlClientBase, SqlJobClientBase, LoadJob +from dlt.load.exceptions import LoadClientSchemaWillNotUpdate, LoadClientTerminalInnerException, LoadClientTransientInnerException +from dlt.load.typing import DBCursor +from dlt.load.sql_client import SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase, 
LoadEmptyJob -from dlt.load.redshift.configuration import configuration, RedshiftClientConfiguration +from dlt.load.redshift import capabilities +from dlt.load.redshift.configuration import RedshiftClientConfiguration SCT_TO_PGT: Dict[TDataType, str] = { @@ -47,7 +48,7 @@ "numeric": "decimal" } -HINT_TO_REDSHIFT_ATTR: Dict[THintType, str] = { +HINT_TO_REDSHIFT_ATTR: Dict[TColumnHint, str] = { "cluster": "DISTKEY", # it is better to not enforce constraints in redshift # "primary_key": "PRIMARY KEY", @@ -57,14 +58,14 @@ class RedshiftSqlClient(SqlClientBase["psycopg2.connection"]): - def __init__(self, default_dataset_name: str, CREDENTIALS: Type[PostgresCredentials]) -> None: + def __init__(self, default_dataset_name: str, credentials: PostgresCredentials) -> None: super().__init__(default_dataset_name) self._conn: psycopg2.connection = None - self.C = CREDENTIALS + self.credentials = credentials def open_connection(self) -> None: self._conn = psycopg2.connect( - **self.C.as_dict(), + **self.credentials, options=f"-c search_path={self.fully_qualified_dataset_name()},public" ) # we'll provide explicit transactions @@ -139,17 +140,13 @@ def fully_qualified_dataset_name(self) -> str: class RedshiftInsertLoadJob(LoadJob): - - MAX_STATEMENT_SIZE = 8 * 1024 * 1024 - - def __init__(self, table_name: str, write_disposition: TWriteDisposition, file_path: str, sql_client: SqlClientBase["psycopg2.connection"]) -> None: - super().__init__(JobClientBase.get_file_name_from_file_path(file_path)) + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._sql_client = sql_client # insert file content immediately self._insert(sql_client.make_qualified_table_name(table_name), write_disposition, file_path) - def status(self) -> LoadJobStatus: + def status(self) -> TLoadJobStatus: # this job is always done return "completed" @@ -174,7 +171,7 @@ def _insert(self, qualified_table_name: str, write_disposition: TWriteDispositio if write_disposition == "replace": insert_sql.append(SQL("DELETE FROM {};").format(SQL(qualified_table_name))) # is_eof = False - while content := f.read(RedshiftInsertLoadJob.MAX_STATEMENT_SIZE): + while content := f.read(RedshiftClient.capabilities()["max_query_length"] // 2): # read one more line in order to # 1. complete the content which ends at "random" position, not an end line # 2. 
to modify it's ending without a need to re-allocating the 8MB of "content" @@ -204,18 +201,21 @@ def _insert(self, qualified_table_name: str, write_disposition: TWriteDispositio class RedshiftClient(SqlJobClientBase): - CONFIG: Type[RedshiftClientConfiguration] = None - CREDENTIALS: Type[PostgresCredentials] = None + # CONFIG: RedshiftClientConfiguration = None + # CREDENTIALS: PostgresCredentials = None - def __init__(self, schema: Schema) -> None: + def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: sql_client = RedshiftSqlClient( - schema.normalize_make_dataset_name(self.CONFIG.DEFAULT_DATASET, self.CONFIG.DEFAULT_SCHEMA_NAME, schema.name), - self.CREDENTIALS + schema.normalize_make_dataset_name(config.dataset_name, config.default_schema_name, schema.name), + config.credentials ) - super().__init__(schema, sql_client) + super().__init__(schema, config, sql_client) + self.config: RedshiftClientConfiguration = config self.sql_client = sql_client - def initialize_storage(self) -> None: + def initialize_storage(self, wipe_data: bool = False) -> None: + if wipe_data: + raise NotImplementedError() if not self.sql_client.has_dataset(): self.sql_client.create_dataset() @@ -223,7 +223,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: # always returns completed jobs as RedshiftInsertLoadJob is executed # atomically in start_file_load so any jobs that should be recreated are already completed # in case of bugs in loader (asking for jobs that were never created) we are not able to detect that - return JobClientBase.make_job_with_status(file_path, "completed") + return LoadEmptyJob.from_file_path(file_path, "completed") def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: try: @@ -235,7 +235,7 @@ def start_file_load(self, table: TTableSchema, file_path: str) -> LoadJob: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) if "Numeric data overflow" in tr_ex.pgerror: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) - if "Precision exceeds maximum": + if "Precision exceeds maximum" in tr_ex.pgerror: raise LoadClientTerminalInnerException("Terminal error, file will not load", tr_ex) raise LoadClientTransientInnerException("Error may go away, will retry", tr_ex) except (psycopg2.DataError, psycopg2.ProgrammingError, psycopg2.IntegrityError) as ter_ex: @@ -338,18 +338,5 @@ def _pq_t_to_sc_t(pq_t: str, precision: Optional[int], scale: Optional[int]) -> return PGT_TO_SCT.get(pq_t, "text") @classmethod - def capabilities(cls) -> TLoaderCapabilities: - return { - "preferred_loader_file_format": "insert_values", - "supported_loader_file_formats": ["insert_values"], - "max_identifier_length": 127, - "max_column_length": 127 - } - - @classmethod - def configure(cls, initial_values: StrAny = None) -> Tuple[Type[RedshiftClientConfiguration], Type[PostgresCredentials]]: - cls.CONFIG, cls.CREDENTIALS = configuration(initial_values=initial_values) - return cls.CONFIG, cls.CREDENTIALS - - -CLIENT = RedshiftClient + def capabilities(cls) -> DestinationCapabilitiesContext: + return capabilities() diff --git a/dlt/load/sql_client.py b/dlt/load/sql_client.py new file mode 100644 index 0000000000..9ede7e6483 --- /dev/null +++ b/dlt/load/sql_client.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from types import TracebackType +from typing import Any, ContextManager, Generic, Iterator, Optional, Sequence, Tuple, Type, AnyStr, Protocol + +from 
dlt.load.typing import TNativeConn, DBCursor + + +class SqlClientBase(ABC, Generic[TNativeConn]): + def __init__(self, default_dataset_name: str) -> None: + if not default_dataset_name: + raise ValueError(default_dataset_name) + self.default_dataset_name = default_dataset_name + + @abstractmethod + def open_connection(self) -> None: + pass + + @abstractmethod + def close_connection(self) -> None: + pass + + def __enter__(self) -> "SqlClientBase[TNativeConn]": + self.open_connection() + return self + + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: + self.close_connection() + + @abstractmethod + def native_connection(self) -> TNativeConn: + pass + + @abstractmethod + def has_dataset(self) -> bool: + pass + + @abstractmethod + def create_dataset(self) -> None: + pass + + @abstractmethod + def drop_dataset(self) -> None: + pass + + @abstractmethod + def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + pass + + @abstractmethod + def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> ContextManager[DBCursor]: + pass + + @abstractmethod + def fully_qualified_dataset_name(self) -> str: + pass + + def make_qualified_table_name(self, table_name: str) -> str: + return f"{self.fully_qualified_dataset_name()}.{table_name}" + + @contextmanager + def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["SqlClientBase[TNativeConn]"]: + current_dataset_name = self.default_dataset_name + try: + self.default_dataset_name = dataset_name + yield self + finally: + # restore previous dataset name + self.default_dataset_name = current_dataset_name diff --git a/dlt/load/typing.py b/dlt/load/typing.py index 62ff129886..f576ae5c97 100644 --- a/dlt/load/typing.py +++ b/dlt/load/typing.py @@ -1,20 +1,9 @@ -from typing import Any, AnyStr, List, Literal, Optional, Tuple, TypeVar, TypedDict +from typing import Any, AnyStr, List, Literal, Optional, Tuple, TypeVar -from dlt.common.dataset_writers import TLoaderFileFormat - - -LoadJobStatus = Literal["running", "failed", "retry", "completed"] # native connection TNativeConn = TypeVar("TNativeConn", bound="object") -class TLoaderCapabilities(TypedDict): - preferred_loader_file_format: TLoaderFileFormat - supported_loader_file_formats: List[TLoaderFileFormat] - max_identifier_length: int - max_column_length: int - - # type for dbapi cursor class DBCursor: closed: Any diff --git a/dlt/normalize/__init__.py b/dlt/normalize/__init__.py index a55a9257f8..a40a5eaa7e 100644 --- a/dlt/normalize/__init__.py +++ b/dlt/normalize/__init__.py @@ -1,2 +1 @@ -from dlt._version import normalize_version as __version__ -from .normalize import Normalize, configuration \ No newline at end of file +from .normalize import Normalize \ No newline at end of file diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index f5090a1f28..6520a25c57 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -1,26 +1,27 @@ -from typing import Type - -from dlt.common.typing import StrAny -from dlt.common.dataset_writers import TLoaderFileFormat -from dlt.common.configuration import (PoolRunnerConfiguration, NormalizeVolumeConfiguration, - LoadVolumeConfiguration, SchemaVolumeConfiguration, - ProductionLoadVolumeConfiguration, ProductionNormalizeVolumeConfiguration, - ProductionSchemaVolumeConfiguration, - TPoolType, make_configuration) - -from . 
import __version__ - - -class NormalizeConfiguration(PoolRunnerConfiguration, NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration): - MAX_EVENTS_IN_CHUNK: int = 40000 # maximum events to be processed in single chunk - LOADER_FILE_FORMAT: TLoaderFileFormat = "jsonl" # jsonp or insert commands will be generated - POOL_TYPE: TPoolType = "process" - - -class ProductionNormalizeConfiguration(ProductionNormalizeVolumeConfiguration, ProductionLoadVolumeConfiguration, - ProductionSchemaVolumeConfiguration, NormalizeConfiguration): - pass - - -def configuration(initial_values: StrAny = None) -> Type[NormalizeConfiguration]: - return make_configuration(NormalizeConfiguration, ProductionNormalizeConfiguration, initial_values=initial_values) +from typing import TYPE_CHECKING + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import LoadVolumeConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, PoolRunnerConfiguration, TPoolType +from dlt.common.destination import DestinationCapabilitiesContext + + +@configspec(init=True) +class NormalizeConfiguration(PoolRunnerConfiguration): + pool_type: TPoolType = "process" + destination_capabilities: DestinationCapabilitiesContext = None # injectable + _schema_storage_config: SchemaVolumeConfiguration + _normalize_storage_config: NormalizeVolumeConfiguration + _load_storage_config: LoadVolumeConfiguration + + if TYPE_CHECKING: + def __init__( + self, + pool_type: TPoolType = None, + workers: int = None, + exit_on_exception: bool = None, + is_single_run: bool = None, + schema_storage_config: SchemaVolumeConfiguration = None, + normalize_storage_config: NormalizeVolumeConfiguration = None, + load_storage_config: LoadVolumeConfiguration = None + ) -> None: + ... 
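A hedged sketch of how the injectable `destination_capabilities` field of `NormalizeConfiguration` might be wired by hand (e.g. in a test), using the `dummy` destination added earlier in this PR; in normal operation the capabilities context is supplied by injection, so the explicit assignment below is an assumption about test-time wiring rather than the documented path.

    from dlt.load.dummy import capabilities
    from dlt.normalize.configuration import NormalizeConfiguration

    caps = capabilities()  # DestinationCapabilitiesContext of the dummy destination ("jsonl" by default)
    config = NormalizeConfiguration(pool_type="process", workers=1)
    # the spec marks destination_capabilities as injectable; set it explicitly here
    config.destination_capabilities = caps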
diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index f387229d05..4d3ee8a4ba 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -4,36 +4,41 @@ from prometheus_client import Counter, CollectorRegistry, REGISTRY, Gauge from dlt.common import pendulum, signals, json, logger +from dlt.common.configuration import with_config +from dlt.common.configuration.specs.load_volume_configuration import LoadVolumeConfiguration +from dlt.common.configuration.specs.normalize_volume_configuration import NormalizeVolumeConfiguration +from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.json import custom_pua_decode -from dlt.cli import TRunnerArgs -from dlt.common.runners import TRunMetrics, Runnable, run_pool, initialize_runner, workermethod +from dlt.common.runners import TRunMetrics, Runnable +from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.storages.exceptions import SchemaNotFoundError -from dlt.common.storages.normalize_storage import NormalizeStorage +from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage from dlt.common.telemetry import get_logging_extras -from dlt.common.utils import uniq_id -from dlt.common.typing import TDataItem +from dlt.common.typing import ConfigValue, StrAny, TDataItem from dlt.common.exceptions import PoolException -from dlt.common.storages import SchemaStorage from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException -from dlt.common.storages.load_storage import LoadStorage -from dlt.normalize.configuration import configuration, NormalizeConfiguration +from dlt.normalize.configuration import NormalizeConfiguration -TMapFuncRV = Tuple[List[TSchemaUpdate], List[Sequence[str]]] -TMapFuncType = Callable[[str, str, Sequence[str]], TMapFuncRV] +# normalize worker wrapping function (map_parallel, map_single) return type +TMapFuncRV = Tuple[int, List[TSchemaUpdate], List[Sequence[str]]] # (total items processed, list of schema updates, list of processed files) +# normalize worker wrapping function signature +TMapFuncType = Callable[[Schema, str, Sequence[str]], TMapFuncRV] # input parameters: (schema name, load_id, list of files to process) class Normalize(Runnable[ProcessPool]): # make our gauges static - event_counter: Counter = None - event_gauge: Gauge = None + item_counter: Counter = None + item_gauge: Gauge = None schema_version_gauge: Gauge = None load_package_counter: Counter = None - def __init__(self, C: Type[NormalizeConfiguration], collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None) -> None: - self.CONFIG = C + @with_config(spec=NormalizeConfiguration, namespaces=("normalize",)) + def __init__(self, collector: CollectorRegistry = REGISTRY, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = ConfigValue) -> None: + self.config = config + self.loader_file_format = config.destination_capabilities.preferred_loader_file_format self.pool: ProcessPool = None self.normalize_storage: NormalizeStorage = None self.load_storage: LoadStorage = None @@ -42,29 +47,32 @@ def __init__(self, C: Type[NormalizeConfiguration], collector: CollectorRegistry # setup storages self.create_storages() # create schema storage with give type - self.schema_storage = schema_storage or SchemaStorage(self.CONFIG, makedirs=True) + self.schema_storage = schema_storage or SchemaStorage(self.config._schema_storage_config, makedirs=True) try: self.create_gauges(collector) 
except ValueError as v: # ignore re-creation of gauges - if "Duplicated timeseries" not in str(v): + if "Duplicated" not in str(v): raise @staticmethod def create_gauges(registry: CollectorRegistry) -> None: - Normalize.event_counter = Counter("normalize_event_count", "Events processed in normalize", ["schema"], registry=registry) - Normalize.event_gauge = Gauge("normalize_last_events", "Number of events processed in last run", ["schema"], registry=registry) + Normalize.item_counter = Counter("normalize_item_count", "Items processed in normalize", ["schema"], registry=registry) + Normalize.item_gauge = Gauge("normalize_last_items", "Number of items processed in last run", ["schema"], registry=registry) Normalize.schema_version_gauge = Gauge("normalize_schema_version", "Current schema version", ["schema"], registry=registry) Normalize.load_package_counter = Gauge("normalize_load_packages_created_count", "Count of load package created", ["schema"], registry=registry) def create_storages(self) -> None: - self.normalize_storage = NormalizeStorage(True, self.CONFIG) + # pass initial normalize storage config embedded in normalize config + self.normalize_storage = NormalizeStorage(True, config=self.config._normalize_storage_config) # normalize saves in preferred format but can read all supported formats - self.load_storage = LoadStorage(True, self.CONFIG, self.CONFIG.LOADER_FILE_FORMAT, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + self.load_storage = LoadStorage(True, self.loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config._load_storage_config) - def load_or_create_schema(self, schema_name: str) -> Schema: + + @staticmethod + def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Schema: try: - schema = self.schema_storage.load_schema(schema_name) + schema = schema_storage.load_schema(schema_name) logger.info(f"Loaded schema with name {schema_name} with version {schema.stored_version}") except SchemaNotFoundError: schema = Schema(schema_name) @@ -72,74 +80,102 @@ def load_or_create_schema(self, schema_name: str) -> Schema: return schema @staticmethod - @workermethod - def w_normalize_files(self: "Normalize", schema_name: str, load_id: str, events_files: Sequence[str]) -> TSchemaUpdate: - normalized_data: Dict[str, List[Any]] = {} + def w_normalize_files( + normalize_storage_config: NormalizeVolumeConfiguration, + loader_storage_config: LoadVolumeConfiguration, + loader_file_format: TLoaderFileFormat, + stored_schema: TStoredSchema, + load_id: str, + extracted_items_files: Sequence[str] + ) -> Tuple[TSchemaUpdate, int]: + schema = Schema.from_stored_schema(stored_schema) + load_storage = LoadStorage(False, loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, loader_storage_config) + normalize_storage = NormalizeStorage(False, normalize_storage_config) schema_update: TSchemaUpdate = {} - schema = self.load_or_create_schema(schema_name) - file_id = uniq_id(5) - - # process all event files and store rows in memory - for events_file in events_files: - i: int = 0 - event: TDataItem = None - try: - logger.debug(f"Processing events file {events_file} in load_id {load_id} with file_id {file_id}") - with self.normalize_storage.storage.open_file(events_file) as f: - events: Sequence[TDataItem] = json.load(f) - for i, event in enumerate(events): - for (table_name, parent_table), row in schema.normalize_data_item(schema, event, load_id): - # filter row, may eliminate some or all fields - row = schema.filter_row(table_name, row) - # do not process empty 
rows - if row: - # decode pua types - for k, v in row.items(): - row[k] = custom_pua_decode(v) # type: ignore - # check if schema can be updated - row, partial_table = schema.coerce_row(table_name, parent_table, row) - if partial_table: - # update schema and save the change - schema.update_schema(partial_table) - table_updates = schema_update.setdefault(table_name, []) - table_updates.append(partial_table) - # store row - rows = normalized_data.setdefault(table_name, []) - rows.append(row) - if i % 100 == 0: - logger.debug(f"Processed {i} of {len(events)} events") - except Exception: - logger.exception(f"Exception when processing file {events_file}, event idx {i}") - logger.debug(f"Affected event: {event}") - raise PoolException("normalize_files", events_file) - - # save rows and return schema changes to be gathered in parent process - for table_name, rows in normalized_data.items(): - # save into new jobs to processed as load - table = schema.get_table_columns(table_name) - self.load_storage.write_temp_job_file(load_id, table_name, table, file_id, rows) - - return schema_update - - def map_parallel(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: - # we chunk files in a way to not exceed MAX_EVENTS_IN_CHUNK and split them equally - # between processors - configured_processes = self.pool._processes # type: ignore - chunk_files = NormalizeStorage.chunk_by_events(files, self.CONFIG.MAX_EVENTS_IN_CHUNK, configured_processes) - logger.info(f"Obtained {len(chunk_files)} processing chunks") - # use id of self to pass the self instance. see `Runnable` class docstrings - param_chunk = [(id(self), schema_name, load_id, files) for files in chunk_files] - return self.pool.starmap(Normalize.w_normalize_files, param_chunk), chunk_files - - def map_single(self, schema_name: str, load_id: str, files: Sequence[str]) -> TMapFuncRV: - chunk_files = NormalizeStorage.chunk_by_events(files, self.CONFIG.MAX_EVENTS_IN_CHUNK, 1) - # get in one chunk - assert len(chunk_files) == 1 - logger.info(f"Obtained {len(chunk_files)} processing chunks") - # use id of self to pass the self instance. 
see `Runnable` class docstrings - self_id: Any = id(self) - return [Normalize.w_normalize_files(self_id, schema_name, load_id, chunk_files[0])], chunk_files + total_items = 0 + + # process all files with data items and write to buffered item storage + try: + for extracted_items_file in extracted_items_files: + line_no: int = 0 + root_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name + logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {root_table_name} and schema {schema.name}") + with normalize_storage.storage.open_file(extracted_items_file) as f: + # enumerate jsonl file line by line + for line_no, line in enumerate(f): + items: List[TDataItem] = json.loads(line) + partial_update, items_count = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items) + schema_update.update(partial_update) + total_items += items_count + logger.debug(f"Processed {line_no} items from file {extracted_items_file}, items {items_count} of total {total_items}") + # if any item found in the file + if items_count > 0: + logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") + except Exception: + logger.exception(f"Exception when processing file {extracted_items_file}, line {line_no}") + # logger.debug(f"Affected item: {item}") + raise PoolException("normalize_files", extracted_items_file) + finally: + load_storage.close_writers(load_id) + + logger.debug(f"Processed total {total_items} items in {len(extracted_items_files)} files") + + return schema_update, total_items + + @staticmethod + def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int]: + column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below + schema_update: TSchemaUpdate = {} + schema_name = schema.name + items_count = 0 + + for item in items: + for (table_name, parent_table), row in schema.normalize_data_item(schema, item, load_id, root_table_name): + # filter row, may eliminate some or all fields + row = schema.filter_row(table_name, row) + # do not process empty rows + if row: + # decode pua types + for k, v in row.items(): + row[k] = custom_pua_decode(v) # type: ignore + # coerce row of values into schema table, generating partial table with new columns if any + row, partial_table = schema.coerce_row(table_name, parent_table, row) + if partial_table: + # update schema and save the change + schema.update_schema(partial_table) + table_updates = schema_update.setdefault(table_name, []) + table_updates.append(partial_table) + # get current columns schema + columns = column_schemas.get(table_name) + if not columns: + columns = schema.get_table_columns(table_name) + column_schemas[table_name] = columns + # store row + load_storage.write_data_item(load_id, schema_name, table_name, row, columns) + # count total items + items_count += 1 + return schema_update, items_count + + def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: + # TODO: maybe we should chunk by file size, now map all files to workers + chunk_files = [files] + schema_dict = schema.to_dict() + config_tuple = (self.normalize_storage.config, self.load_storage.config, self.loader_file_format, schema_dict) + param_chunk = [(*config_tuple, load_id, files) for files in chunk_files] + processed_chunks = 
self.pool.starmap(Normalize.w_normalize_files, param_chunk) + return sum([t[1] for t in processed_chunks]), [t[0] for t in processed_chunks], chunk_files + + def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV: + processed_chunk = Normalize.w_normalize_files( + self.normalize_storage.config, + self.load_storage.config, + self.loader_file_format, + schema.to_dict(), + load_id, + files + ) + return processed_chunk[1], [processed_chunk[0]], [files] def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> int: updates_count = 0 @@ -152,10 +188,11 @@ def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> return updates_count def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files: Sequence[str]) -> None: + schema = Normalize.load_or_create_schema(self.schema_storage, schema_name) + # process files in parallel or in single thread, depending on map_f - schema_updates, chunk_files = map_f(schema_name, load_id, files) + total_items, schema_updates, chunk_files = map_f(schema, load_id, files) - schema = self.load_or_create_schema(schema_name) # gather schema from all manifests, validate consistency and combine updates_count = self.update_schema(schema, schema_updates) self.schema_version_gauge.labels(schema_name).set(schema.version) @@ -173,18 +210,16 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files logger.info("Committing storage, do not kill this process") # rename temp folder to processing self.load_storage.commit_temp_load_package(load_id) - # delete event files and count events to provide metrics - total_events = 0 - for event_file in chain.from_iterable(chunk_files): # flatten chunks - self.normalize_storage.storage.delete(event_file) - total_events += NormalizeStorage.get_events_count(event_file) + # delete item files to complete commit + for item_file in chain.from_iterable(chunk_files): # flatten chunks + self.normalize_storage.storage.delete(item_file) # log and update metrics logger.info(f"Chunk {load_id} processed") self.load_package_counter.labels(schema_name).inc() - self.event_counter.labels(schema_name).inc(total_events) - self.event_gauge.labels(schema_name).set(total_events) + self.item_counter.labels(schema_name).inc(total_items) + self.item_gauge.labels(schema_name).set(total_items) logger.metrics("Normalize metrics", extra=get_logging_extras( - [self.load_package_counter.labels(schema_name), self.event_counter.labels(schema_name), self.event_gauge.labels(schema_name)])) + [self.load_package_counter.labels(schema_name), self.item_counter.labels(schema_name), self.item_gauge.labels(schema_name)])) def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: # normalized files will go here before being atomically renamed @@ -192,9 +227,11 @@ def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: self.load_storage.create_temp_load_package(load_id) logger.info(f"Created temp load folder {load_id} on loading volume") + # if pool is not present use map_single method to run normalization in single process + map_parallel_f = self.map_parallel if self.pool else self.map_single try: # process parallel - self.spool_files(schema_name, load_id, self.map_parallel, files) + self.spool_files(schema_name, load_id, map_parallel_f, files) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing logger.warning(f"Parallel schema update conflict, switching to single thread 
({str(exc)}") @@ -207,11 +244,10 @@ def spool_schema_files(self, schema_name: str, files: Sequence[str]) -> str: def run(self, pool: ProcessPool) -> TRunMetrics: # keep the pool in class instance self.pool = pool - logger.info("Running file normalizing") # list files and group by schema name, list must be sorted for group by to actually work files = self.normalize_storage.list_files_to_normalize_sorted() - logger.info(f"Found {len(files)} files, will process in chunks of {self.CONFIG.MAX_EVENTS_IN_CHUNK} of events") + logger.info(f"Found {len(files)} files") if len(files) == 0: return TRunMetrics(True, False, 0) # group files by schema @@ -220,20 +256,3 @@ def run(self, pool: ProcessPool) -> TRunMetrics: self.spool_schema_files(schema_name, list(files_in_schema)) # return info on still pending files (if extractor saved something in the meantime) return TRunMetrics(False, False, len(self.normalize_storage.list_files_to_normalize_sorted())) - - -def main(args: TRunnerArgs) -> int: - # initialize runner - C = configuration(args._asdict()) - initialize_runner(C) - # create objects and gauges - try: - n = Normalize(C, REGISTRY) - except Exception: - logger.exception("init module") - return -1 - return run_pool(C, n) - - -def run_main(args: TRunnerArgs) -> None: - exit(main(args)) diff --git a/experiments/pipeline/README.md b/dlt/pipeline/README.md similarity index 100% rename from experiments/pipeline/README.md rename to dlt/pipeline/README.md diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 661ddc5ec9..ee30fa952f 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -1,4 +1,85 @@ -from dlt.common.schema import Schema # noqa: F401 -from dlt.pipeline.pipeline import Pipeline # noqa: F401 -from dlt.pipeline.typing import GCPPipelineCredentials, PostgresPipelineCredentials # noqa: F401 -from dlt.pipeline.exceptions import CannotRestorePipelineException # noqa: F401 +from typing import Union, cast + +from dlt.common.typing import TSecretValue, Any +from dlt.common.configuration import with_config +from dlt.common.configuration.container import Container +from dlt.common.destination import DestinationReference +from dlt.common.pipeline import PipelineContext, get_default_working_dir + +from dlt.pipeline.configuration import PipelineConfiguration +from dlt.pipeline.pipeline import Pipeline +from dlt.extract.decorators import source, resource + + +@with_config(spec=PipelineConfiguration, auto_namespace=True) +def pipeline( + pipeline_name: str = None, + working_dir: str = None, + pipeline_secret: TSecretValue = None, + destination: Union[None, str, DestinationReference] = None, + dataset_name: str = None, + import_schema_path: str = None, + export_schema_path: str = None, + always_drop_pipeline: bool = False, + **kwargs: Any +) -> Pipeline: + # call without parameters returns current pipeline + if not locals(): + context = Container()[PipelineContext] + # if pipeline instance is already active then return it, otherwise create a new one + if context.is_activated(): + return cast(Pipeline, context.pipeline()) + + # if working_dir not provided use temp folder + if not working_dir: + working_dir = get_default_working_dir() + + destination = DestinationReference.from_name(destination) + # create new pipeline instance + p = Pipeline(pipeline_name, working_dir, pipeline_secret, destination, dataset_name, import_schema_path, export_schema_path, always_drop_pipeline, False, kwargs["runtime"]) + # set it as current pipeline + Container()[PipelineContext].activate(p) + + 
return p + + +def restore( + pipeline_name: str = None, + working_dir: str = None, + pipeline_secret: TSecretValue = None +) -> Pipeline: + + _pipeline_name = pipeline_name + _working_dir = working_dir + + @with_config(spec=PipelineConfiguration, auto_namespace=True) + def _restore( + pipeline_name: str, + working_dir: str, + pipeline_secret: TSecretValue, + always_drop_pipeline: bool = False, + **kwargs: Any + ) -> Pipeline: + # use the outer pipeline name and working dir to override those from config in order to restore the requested state + pipeline_name = _pipeline_name or pipeline_name + working_dir = _working_dir or working_dir + + # if working_dir not provided use temp folder + if not working_dir: + working_dir = get_default_working_dir() + # create new pipeline instance + p = Pipeline(pipeline_name, working_dir, pipeline_secret, None, None, None, None, always_drop_pipeline, True, kwargs["runtime"]) + # set it as current pipeline + Container()[PipelineContext].activate(p) + return p + + return _restore(pipeline_name, working_dir, pipeline_secret) + + +# setup default pipeline in the container +Container()[PipelineContext] = PipelineContext(pipeline) + + +def run(source: Any, destination: Union[None, str, DestinationReference] = None) -> Pipeline: + destination = DestinationReference.from_name(destination) + return pipeline().run(source=source, destination=destination) diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py new file mode 100644 index 0000000000..ccf78bf243 --- /dev/null +++ b/dlt/pipeline/configuration.py @@ -0,0 +1,34 @@ +from typing import ClassVar, Optional, TYPE_CHECKING +from typing_extensions import runtime + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration, ContainerInjectableContext +from dlt.common.typing import TSecretValue +from dlt.common.utils import uniq_id + +from dlt.pipeline.typing import TPipelineState + + +@configspec +class PipelineConfiguration(BaseConfiguration): + pipeline_name: Optional[str] = None + working_dir: Optional[str] = None + pipeline_secret: Optional[TSecretValue] = None + _runtime: RunConfiguration + + def check_integrity(self) -> None: + if not self.pipeline_secret: + self.pipeline_secret = TSecretValue(uniq_id()) + if not self.pipeline_name: + self.pipeline_name = self._runtime.pipeline_name + + +@configspec(init=True) +class StateInjectableContext(ContainerInjectableContext): + state: TPipelineState + + can_create_default: ClassVar[bool] = False + + if TYPE_CHECKING: + def __init__(self, state: TPipelineState = None) -> None: + ... 
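The new dlt/pipeline/__init__.py above replaces the old Pipeline entry point with module-level pipeline(), restore() and run() factories. A minimal usage sketch based only on the signatures added in this hunk; the pipeline name, working directory, destination name and sample rows are illustrative assumptions, not values taken from the codebase:

from dlt.pipeline import pipeline, restore

# create (or reuse) a pipeline; the destination name is resolved via DestinationReference.from_name
p = pipeline(
    pipeline_name="demo_pipeline",
    working_dir="_storage/demo",   # if omitted, get_default_working_dir() is used
    destination="dummy",           # illustrative destination name
    dataset_name="demo_dataset",
)
p.extract([{"id": 1}, {"id": 2}], table_name="items")
p.normalize()
p.load()

# later, possibly from another process, restore the same pipeline from its working directory
r = restore(pipeline_name="demo_pipeline", working_dir="_storage/demo")
for load_id in r.list_completed_load_packages():
    print(r.list_failed_jobs_in_package(load_id))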
diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index a250b31047..5655243214 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -1,7 +1,7 @@ from typing import Any, Sequence -from dlt.common.exceptions import DltException +from dlt.common.exceptions import DltException, ArgumentsOverloadException from dlt.common.telemetry import TRunMetrics -from dlt.pipeline.typing import TPipelineStage +from dlt.pipeline.typing import TPipelineStep class PipelineException(DltException): @@ -27,26 +27,20 @@ def _get_msg(self, appendix: str) -> str: def _to_pip_install(self) -> str: return "\n".join([f"pip install {d}" for d in self.dependencies]) - -class NoPipelineException(PipelineException): - def __init__(self) -> None: - super().__init__("Please create or restore pipeline before using this function") - - -class InvalidPipelineContextException(PipelineException): - def __init__(self) -> None: - super().__init__("There may be just one active pipeline in single python process. You may have switch between pipelines by restoring pipeline just before using load method") +class PipelineConfigMissing(PipelineException): + def __init__(self, config_elem: str, step: TPipelineStep, _help: str = None) -> None: + self.config_elem = config_elem + self.step = step + msg = f"Configuration element {config_elem} was not provided and {step} step cannot be executed" + if _help: + msg += f"\n{_help}\n" + super().__init__(msg) class CannotRestorePipelineException(PipelineException): - def __init__(self, reason: str) -> None: - super().__init__(reason) - - -class PipelineBackupNotFound(PipelineException): - def __init__(self, method: str) -> None: - self.method = method - super().__init__(f"Backup not found for method {method}") + def __init__(self, pipeline_name: str, working_dir: str, reason: str) -> None: + msg = f"Pipeline with name {pipeline_name} in working directory {working_dir} could not be restored: {reason}" + super().__init__(msg) class SqlClientNotAvailable(PipelineException): @@ -54,19 +48,9 @@ def __init__(self, client_type: str) -> None: super().__init__(f"SQL Client not available in {client_type}") -class InvalidIteratorException(PipelineException): - def __init__(self, iterator: Any) -> None: - super().__init__(f"Unsupported source iterator or iterable type: {type(iterator).__name__}") - - -class InvalidItemException(PipelineException): - def __init__(self, item: Any) -> None: - super().__init__(f"Source yielded unsupported item type: {type(item).__name}. 
Only dictionaries, sequences and deferred items allowed.") - - class PipelineStepFailed(PipelineException): - def __init__(self, stage: TPipelineStage, exception: BaseException, run_metrics: TRunMetrics) -> None: - self.stage = stage + def __init__(self, step: TPipelineStep, exception: BaseException, run_metrics: TRunMetrics) -> None: + self.stage = step self.exception = exception self.run_metrics = run_metrics - super().__init__(f"Pipeline execution failed at stage {stage} with exception:\n\n{type(exception)}\n{exception}") + super().__init__(f"Pipeline execution failed at stage {step} with exception:\n\n{type(exception)}\n{exception}") diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index e3502092c4..a633ebf49f 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,296 +1,477 @@ - +import os from contextlib import contextmanager -from copy import deepcopy -import yaml -from collections import abc -from dataclasses import asdict as dtc_asdict -import tempfile -import os.path -from typing import Any, Iterator, List, Sequence, Tuple, Callable -from prometheus_client import REGISTRY - -from dlt.common import json, sleep, signals, logger +from functools import wraps +from collections.abc import Sequence as C_Sequence +from typing import Any, Callable, ClassVar, List, Iterator, Mapping, Sequence, Tuple, get_type_hints, overload + +from dlt.common import json, logger, signals +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_namespace_context import ConfigNamespacesContext +from dlt.common.runners.runnable import Runnable +from dlt.common.schema.typing import TColumnSchema, TWriteDisposition +from dlt.common.storages.load_storage import LoadStorage +from dlt.common.typing import ParamSpec, TFun, TSecretValue + from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.configuration import PoolRunnerConfiguration, make_configuration -from dlt.common.file_storage import FileStorage -from dlt.common.schema import Schema, normalize_schema_name -from dlt.common.typing import DictStrAny, StrAny -from dlt.common.utils import uniq_id, is_interactive -from dlt.common.sources import DLT_METADATA_FIELD, TItem, with_table_name - -from dlt.extract.extractor_storage import ExtractorStorageBase -from dlt.load.client_base import SqlClientBase, SqlJobClientBase -from dlt.normalize.configuration import configuration as normalize_configuration -from dlt.load.configuration import configuration as loader_configuration +from dlt.common.storages import LiveSchemaStorage, NormalizeStorage + +from dlt.common.configuration import inject_namespace +from dlt.common.configuration.specs import RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, LoadVolumeConfiguration, PoolRunnerConfiguration +from dlt.common.destination import DestinationCapabilitiesContext, DestinationReference, JobClientBase, DestinationClientConfiguration, DestinationClientDwhConfiguration +from dlt.common.schema import Schema +from dlt.common.storages.file_storage import FileStorage +from dlt.common.utils import is_interactive + +from dlt.extract.extract import ExtractorStorage, extract +from dlt.extract.source import DltResource, DltSource from dlt.normalize import Normalize +from dlt.normalize.configuration import NormalizeConfiguration +from dlt.load.sql_client import SqlClientBase +from dlt.load.job_client_impl import SqlJobClientBase +from dlt.load.configuration import LoaderConfiguration from 
dlt.load import Load -from dlt.pipeline.exceptions import MissingDependencyException, NoPipelineException, PipelineStepFailed, CannotRestorePipelineException, SqlClientNotAvailable -from dlt.pipeline.typing import PipelineCredentials + +from dlt.pipeline.exceptions import CannotRestorePipelineException, PipelineConfigMissing, MissingDependencyException, PipelineStepFailed, SqlClientNotAvailable +from dlt.pipeline.typing import TPipelineStep, TPipelineState +from dlt.pipeline.configuration import StateInjectableContext + + +def with_state_sync(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + # backup and restore state + with self._managed_state() as state: + # add the state to container as a context + with self._container.injectable_context(StateInjectableContext(state=state)): + return f(self, *args, **kwargs) + + return _wrap # type: ignore + +def with_schemas_sync(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + for name in self._schema_storage.live_schemas: + # refresh live schemas in storage or import schema path + self._schema_storage.commit_live_schema(name) + return f(self, *args, **kwargs) + + return _wrap # type: ignore + +def with_config_namespace(namespaces: Tuple[str, ...]) -> Callable[[TFun], TFun]: + + def decorator(f: TFun) -> TFun: + + @wraps(f) + def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: + # add namespace context to the container to be used by all configuration without explicit namespaces resolution + with inject_namespace(ConfigNamespacesContext(pipeline_name=self.pipeline_name, namespaces=namespaces)): + return f(self, *args, **kwargs) + + return _wrap # type: ignore + + return decorator class Pipeline: - def __init__(self, pipeline_name: str, log_level: str = "INFO") -> None: - self.pipeline_name = pipeline_name - self.root_path: str = None - self.export_schema_path: str = None - self.import_schema_path: str = None - self.root_storage: FileStorage = None - self.credentials: PipelineCredentials = None - self.extractor_storage: ExtractorStorageBase = None - self.default_schema_name: str = None - self.state: DictStrAny = {} - - # addresses of pipeline components to be verified before they are run - self._normalize_instance: Normalize = None - self._loader_instance: Load = None - - # patch config and initialize pipeline - self.C = make_configuration(PoolRunnerConfiguration, PoolRunnerConfiguration, initial_values={ - "PIPELINE_NAME": pipeline_name, - "LOG_LEVEL": log_level, - "POOL_TYPE": "None", - "IS_SINGLE_RUN": True, - "WAIT_RUNS": 0, - "EXIT_ON_EXCEPTION": True, - }) - initialize_runner(self.C) - - def create_pipeline( + + STATE_FILE: ClassVar[str] = "state.json" + STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineState).keys()) + + pipeline_name: str + default_schema_name: str + always_drop_pipeline: bool + working_dir: str + pipeline_root: str + destination: DestinationReference + dataset_name: str + + def __init__( + self, + pipeline_name: str, + working_dir: str, + pipeline_secret: TSecretValue, + destination: DestinationReference, + dataset_name: str, + import_schema_path: str, + export_schema_path: str, + always_drop_pipeline: bool, + must_restore_pipeline: bool, + runtime: RunConfiguration + ) -> None: + self.pipeline_secret = pipeline_secret + self.runtime_config = runtime + + self._container = Container() + # self._state: TPipelineState = {} # type: ignore + self._pipeline_storage: FileStorage = None + self._schema_storage: 
LiveSchemaStorage = None + self._schema_storage_config: SchemaVolumeConfiguration = None + self._normalize_storage_config: NormalizeVolumeConfiguration = None + self._load_storage_config: LoadVolumeConfiguration = None + + initialize_runner(self.runtime_config) + # initialize pipeline working dir + self._init_working_dir(pipeline_name, working_dir) + # initialize or restore state + with self._managed_state(): + # see if state didn't change the pipeline name + if pipeline_name != self.pipeline_name: + raise CannotRestorePipelineException(pipeline_name, working_dir, f"working directory contains state for pipeline with name {self.pipeline_name}") + # at this moment state is recovered so we overwrite the state with the values from init + self.destination = destination or self.destination # changing the destination could be dangerous if pipeline has not loaded items + self.dataset_name = dataset_name or self.dataset_name + self.always_drop_pipeline = always_drop_pipeline + self._configure(import_schema_path, export_schema_path, must_restore_pipeline) + + def drop(self) -> "Pipeline": + """Deletes existing pipeline state, schemas and drops datasets at the destination if present""" + if self.destination: + # drop the data for all known schemas + for schema in self._schema_storage: + with self._get_destination_client(self._schema_storage.load_schema(schema)) as client: + client.initialize_storage(wipe_data=True) + # reset the pipeline working dir + self._create_pipeline() + # clone the pipeline + return Pipeline( + self.pipeline_name, + self.working_dir, + self.pipeline_secret, + self.destination, + self.dataset_name, + self._schema_storage.config.import_schema_path, + self._schema_storage.config.export_schema_path, + self.always_drop_pipeline, + True, + self.runtime_config + ) + + + # @overload + # def extract( + # self, + # data: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], + # table_name = None, + # write_disposition = None, + # parent = None, + # columns = None, + # max_parallel_data_items: int = 20, + # schema: Schema = None + # ) -> None: + # ... + + # @overload + # def extract( + # self, + # data: DltSource, + # max_parallel_iterators: int = 1, + # max_parallel_data_items: int = 20, + # schema: Schema = None + # ) -> None: + # ... 
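The extract/normalize/load methods that follow are wrapped by the with_state_sync, with_schemas_sync and with_config_namespace decorators defined at the top of this pipeline.py hunk. A generic, self-contained sketch of that wrapper pattern (run the method body inside a context manager); example_context and its printed messages are purely hypothetical stand-ins for injectable_context / inject_namespace:

from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Iterator, TypeVar, cast

TFun = TypeVar("TFun", bound=Callable[..., Any])

@contextmanager
def example_context(name: str) -> Iterator[None]:
    # hypothetical context; the real decorators enter state or namespace contexts
    print(f"entering {name}")
    try:
        yield
    finally:
        print(f"leaving {name}")

def with_example_context(f: TFun) -> TFun:
    @wraps(f)
    def _wrap(self: Any, *args: Any, **kwargs: Any) -> Any:
        # execute the wrapped method inside the context, mirroring with_config_namespace
        with example_context(type(self).__name__):
            return f(self, *args, **kwargs)
    return cast(TFun, _wrap)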
+ + @with_schemas_sync + @with_state_sync + @with_config_namespace(("extract",)) + def extract( self, - credentials: PipelineCredentials, - working_dir: str = None, + data: Any, + table_name: str = None, + parent_table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, schema: Schema = None, - import_schema_path: str = None, - export_schema_path: str = None + *, + max_parallel_items: int = 100, + workers: int = 5 ) -> None: - # initialize root storage - if not working_dir: - working_dir = tempfile.mkdtemp() - self.root_storage = FileStorage(working_dir, makedirs=True) - self.export_schema_path = export_schema_path - self.import_schema_path = import_schema_path - - # check if directory contains restorable pipeline + + + # def has_hint_args() -> bool: + # return table_name or parent_table_name or write_disposition or schema + + def apply_hint_args(resource: DltResource) -> None: + columns_dict = None + if columns: + columns_dict = {c["name"]:c for c in columns} + resource.apply_hints(table_name, parent_table_name, write_disposition, columns_dict) + + def choose_schema() -> Schema: + if schema: + return schema + if self.default_schema_name: + return self.default_schema + return Schema(self.pipeline_name) + + # a list of sources or a list of resources may be passed as data + sources: List[DltSource] = [] + + def item_to_source(data_item: Any) -> DltSource: + if isinstance(data_item, DltSource): + # if schema is explicit then override source schema + if schema: + data_item.schema = schema + # try to apply hints to resources + resources = data_item.resources.values() + for r in resources: + apply_hint_args(r) + return data_item + + if isinstance(data_item, DltResource): + # apply hints + apply_hint_args(data_item) + # package resource in source + return DltSource(choose_schema(), [data_item]) + + # iterator/iterable/generator + # create resource first without table template + resource = DltResource.from_data(data_item, name=table_name) + # apply hints + apply_hint_args(resource) + # wrap resource in source + return DltSource(choose_schema(), [resource]) + + if isinstance(data, C_Sequence) and len(data) > 0: + # if first element is source or resource + if isinstance(data[0], DltResource): + sources.append(item_to_source(DltSource(choose_schema(), data))) + elif isinstance(data[0], DltSource): + for s in data: + sources.append(item_to_source(s)) + else: + sources.append(item_to_source(data)) + else: + sources.append(item_to_source(data)) + try: - self._restore_state() - # wipe out the old pipeline - self.root_storage.delete_folder("", recursively=True) - self.root_storage.create_folder("") - except FileNotFoundError: - pass + # extract all sources + for s in sources: + self._extract_source(s, max_parallel_items, workers) + except Exception as exc: + # TODO: provide metrics from extractor + raise PipelineStepFailed("extract", exc, runner.LAST_RUN_METRICS) from exc - self.root_path = self.root_storage.storage_path - self.credentials = credentials - self._load_modules() - self.extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_path, "extractor"), makedirs=True), - self._normalize_instance.normalize_storage) - # create new schema if no default supplied - if schema is None: - # try to load schema, that will also import it - schema_name = normalize_schema_name(self.pipeline_name) - try: - schema = self._normalize_instance.schema_storage.load_schema(schema_name) - except FileNotFoundError: - # create 
new empty schema - schema = Schema(schema_name) - # initialize empty state, this must be last operation when creating pipeline so restore reads only fully created ones - with self._managed_state(): - self.state = { - # "default_schema_name": default_schema_name, - "pipeline_name": self.pipeline_name, - # TODO: must come from resolved configuration - "loader_client_type": credentials.CLIENT_TYPE, - # TODO: must take schema prefix from resolved configuration - "loader_schema_prefix": credentials.default_dataset - } - # persist schema with the pipeline - self.set_default_schema(schema) - - def restore_pipeline( + + @with_schemas_sync + @with_config_namespace(("normalize",)) + def normalize(self, workers: int = 1) -> None: + if is_interactive() and workers > 1: + raise NotImplementedError("Do not use normalize workers in interactive mode ie. in notebook") + # check if any schema is present, if not then no data was extracted + if not self.default_schema_name: + return + + # get destination capabilities + destination_caps = self._get_destination_capabilities() + # create default normalize config + normalize_config = NormalizeConfiguration( + is_single_run=True, + exit_on_exception=True, + workers=workers, + pool_type="none" if workers == 1 else "process", + schema_storage_config=self._schema_storage_config, + normalize_storage_config=self._normalize_storage_config, + load_storage_config=self._load_storage_config + ) + # run with destination context + with self._container.injectable_context(destination_caps): + # shares schema storage with the pipeline so we do not need to install + normalize = Normalize(config=normalize_config, schema_storage=self._schema_storage) + self._run_step_in_pool("normalize", normalize, normalize.config) + + @with_schemas_sync + @with_state_sync + @with_config_namespace(("load",)) + def load( self, - credentials: PipelineCredentials, - working_dir: str, - import_schema_path: str = None, - export_schema_path: str = None + destination: DestinationReference = None, + dataset_name: str = None, + credentials: Any = None, + # raise_on_failed_jobs = False, + # raise_on_incompatible_schema = False, + always_wipe_storage: bool = False, + *, + workers: int = 20 ) -> None: - try: - # do not create extractor dir - it must exist - self.root_storage = FileStorage(working_dir, makedirs=False) - # restore state, this must be a first operation when restoring pipeline - try: - self._restore_state() - except FileNotFoundError: - raise CannotRestorePipelineException(f"Cannot find a valid pipeline in {working_dir}") - restored_name = self.state["pipeline_name"] - if self.pipeline_name != restored_name: - raise CannotRestorePipelineException(f"Expected pipeline {self.pipeline_name}, found {restored_name} pipeline instead") - self.default_schema_name = self.state["default_schema_name"] - if not credentials.default_dataset: - credentials.default_dataset = self.state["loader_schema_prefix"] - self.root_path = self.root_storage.storage_path - self.credentials = credentials - self.export_schema_path = export_schema_path - self.import_schema_path = import_schema_path - self._load_modules() - # schema must exist - try: - self.get_default_schema() - except (FileNotFoundError): - raise CannotRestorePipelineException(f"Default schema with name {self.default_schema_name} not found") - self.extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_path, "extractor"), makedirs=False), - self._normalize_instance.normalize_storage - ) - except 
CannotRestorePipelineException: - raise - def extract(self, items: Iterator[TItem], schema_name: str = None, table_name: str = None) -> None: - # check if iterator or iterable is supported - # if isinstance(items, str) or isinstance(items, dict) or not - # TODO: check if schema exists - with self._managed_state(): - default_table_name = table_name or self.pipeline_name - # TODO: this is not very effective - we consume iterator right away, better implementation needed where we stream iterator to files directly - all_items: List[DictStrAny] = [] - for item in items: - # dispatch items by type - if callable(item): - item = item() - if isinstance(item, dict): - all_items.append(item) - elif isinstance(item, abc.Sequence): - all_items.extend(item) - # react to CTRL-C and shutdowns from controllers - signals.raise_if_signalled() - - try: - self._extract_iterator(default_table_name, all_items) - except Exception: - raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) - - def normalize(self, workers: int = 1, max_events_in_chunk: int = 100000) -> int: - if is_interactive() and workers > 1: - raise NotImplementedError("Do not use workers in interactive mode ie. in notebook") - self._verify_normalize_instance() - # set runtime parameters - self._normalize_instance.CONFIG.WORKERS = workers - self._normalize_instance.CONFIG.MAX_EVENTS_IN_CHUNK = max_events_in_chunk - # switch to thread pool for single worker - self._normalize_instance.CONFIG.POOL_TYPE = "thread" if workers == 1 else "process" - try: - ec = runner.run_pool(self._normalize_instance.CONFIG, self._normalize_instance) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex + # set destination and default dataset if provided + self.destination = destination or self.destination + self.dataset_name = dataset_name or self.dataset_name + # check if any schema is present, if not then no data was extracted + if not self.default_schema_name: + return + + # make sure that destination is set and client is importable and can be instantiated + client_initial_config = self._get_destination_client_initial_config(credentials) + self._get_destination_client(self.default_schema, client_initial_config) + + # create initial loader config and the loader + load_config = LoaderConfiguration( + is_single_run=True, + exit_on_exception=True, + workers=workers, + always_wipe_storage=always_wipe_storage or self.always_drop_pipeline, + load_storage_config=self._load_storage_config + ) + load = Load(self.destination, is_storage_owner=False, config=load_config, initial_client_config=client_initial_config) + self._run_step_in_pool("load", load, load.config) + + @with_config_namespace(("run",)) + def run( + self, + source: Any = None, + destination: DestinationReference = None, + dataset_name: str = None, + table_name: str = None, + write_disposition: TWriteDisposition = None, + columns: Sequence[TColumnSchema] = None, + schema: Schema = None + ) -> None: + # set destination and default dataset if provided + self.destination = destination or self.destination + self.dataset_name = dataset_name or self.dataset_name + # normalize and load pending data + self.normalize() + self.load(destination, 
dataset_name) - def load(self, max_parallel_loads: int = 20) -> int: - self._verify_loader_instance() - self._loader_instance.CONFIG.WORKERS = max_parallel_loads - self._loader_instance.load_client_cls.CONFIG.DEFAULT_SCHEMA_NAME = self.default_schema_name # type: ignore - try: - ec = runner.run_pool(self._loader_instance.CONFIG, self._loader_instance) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex + # extract from the source + if source: + self.extract(source, table_name, write_disposition, None, columns, schema) + self.normalize() + self.load(destination, dataset_name) - def flush(self) -> None: - self.normalize() - self.load() + @property + def schemas(self) -> Mapping[str, Schema]: + return self._schema_storage @property - def working_dir(self) -> str: - return os.path.abspath(self.root_path) + def default_schema(self) -> Schema: + return self.schemas[self.default_schema_name] @property def last_run_exception(self) -> BaseException: return runner.LAST_RUN_EXCEPTION - def list_extracted_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._normalize_instance.normalize_storage.list_files_to_normalize_sorted() + def list_extracted_resources(self) -> Sequence[str]: + return self._get_normalize_storage().list_files_to_normalize_sorted() - def list_normalized_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._loader_instance.load_storage.list_packages() + def list_normalized_load_packages(self) -> Sequence[str]: + return self._get_load_storage().list_packages() - def list_completed_loads(self) -> Sequence[str]: - self._verify_loader_instance() - return self._loader_instance.load_storage.list_completed_packages() + def list_completed_load_packages(self) -> Sequence[str]: + return self._get_load_storage().list_completed_packages() - def list_failed_jobs(self, load_id: str) -> Sequence[Tuple[str, str]]: - self._verify_loader_instance() + def list_failed_jobs_in_package(self, load_id: str) -> Sequence[Tuple[str, str]]: + storage = self._get_load_storage() failed_jobs: List[Tuple[str, str]] = [] - for file in self._loader_instance.load_storage.list_completed_failed_jobs(load_id): + for file in storage.list_completed_failed_jobs(load_id): if not file.endswith(".exception"): try: - failed_message = self._loader_instance.load_storage.storage.load(file + ".exception") + failed_message = storage.storage.load(file + ".exception") except FileNotFoundError: failed_message = None - failed_jobs.append((file, failed_message)) + failed_jobs.append((storage.storage.make_full_path(file), failed_message)) return failed_jobs - def get_default_schema(self) -> Schema: - self._verify_normalize_instance() - return self._normalize_instance.schema_storage.load_schema(self.default_schema_name) - - def set_default_schema(self, new_schema: Schema) -> None: - if self.default_schema_name: - # delete old schema - try: - self._normalize_instance.schema_storage.remove_schema(self.default_schema_name) - self.default_schema_name = None - except FileNotFoundError: - pass - # save new schema - self._normalize_instance.schema_storage.save_schema(new_schema) - self.default_schema_name = new_schema.name - with self._managed_state(): - 
self.state["default_schema_name"] = self.default_schema_name - - def add_schema(self, aux_schema: Schema) -> None: - self._normalize_instance.schema_storage.save_schema(aux_schema) - - def get_schema(self, name: str) -> Schema: - return self._normalize_instance.schema_storage.load_schema(name) - - def remove_schema(self, name: str) -> None: - self._normalize_instance.schema_storage.remove_schema(name) - - def sync_schema(self) -> None: - self._verify_loader_instance() - schema = self._normalize_instance.schema_storage.load_schema(self.default_schema_name) - with self._loader_instance.load_client_cls(schema) as client: - client.initialize_storage() + def sync_schema(self, schema_name: str = None) -> None: + with self._get_destination_client(self.schemas[schema_name]) as client: + client.initialize_storage(wipe_data=self.always_drop_pipeline) client.update_storage_schema() def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: - self._verify_loader_instance() - schema = self._normalize_instance.schema_storage.load_schema(schema_name or self.default_schema_name) - with self._loader_instance.load_client_cls(schema) as c: - if isinstance(c, SqlJobClientBase): - return c.sql_client + with self._get_destination_client(self.schemas[schema_name]) as client: + if isinstance(client, SqlJobClientBase): + return client.sql_client else: - raise SqlClientNotAvailable(self._loader_instance.CONFIG.CLIENT_TYPE) + raise SqlClientNotAvailable(self.destination.__name__) + + def _get_normalize_storage(self) -> NormalizeStorage: + return NormalizeStorage(True, self._normalize_storage_config) - def run_in_pool(self, run_f: Callable[..., Any]) -> int: - # internal runners should work in single mode - self._loader_instance.CONFIG.IS_SINGLE_RUN = True - self._loader_instance.CONFIG.EXIT_ON_EXCEPTION = True - self._normalize_instance.CONFIG.IS_SINGLE_RUN = True - self._normalize_instance.CONFIG.EXIT_ON_EXCEPTION = True + def _get_load_storage(self) -> LoadStorage: + caps = self._get_destination_capabilities() + return LoadStorage(True, caps.preferred_loader_file_format, caps.supported_loader_file_formats, self._load_storage_config) + + def _init_working_dir(self, pipeline_name: str, working_dir: str) -> None: + self.pipeline_name = pipeline_name + self.working_dir = working_dir + # compute the folder that keeps all of the pipeline state + FileStorage.validate_file_name_component(self.pipeline_name) + self.pipeline_root = os.path.join(working_dir, pipeline_name) + # create pipeline working dir + self._pipeline_storage = FileStorage(self.pipeline_root, makedirs=False) + + def _configure(self, import_schema_path: str, export_schema_path: str, must_restore_pipeline: bool) -> None: + # create default configs + self._schema_storage_config = SchemaVolumeConfiguration( + schema_volume_path=os.path.join(self.pipeline_root, "schemas"), + import_schema_path=import_schema_path, + export_schema_path=export_schema_path + ) + self._normalize_storage_config = NormalizeVolumeConfiguration(normalize_volume_path=os.path.join(self.pipeline_root, "normalize")) + self._load_storage_config = LoadVolumeConfiguration(load_volume_path=os.path.join(self.pipeline_root, "load"),) + + # are we running again? 
+ has_state = self._pipeline_storage.has_file(Pipeline.STATE_FILE) + if must_restore_pipeline and not has_state: + raise CannotRestorePipelineException(self.pipeline_name, self.working_dir, f"the pipeline was not found in {self.pipeline_root}.") + + # restore pipeline if folder exists and contains state + if has_state and (not self.always_drop_pipeline or must_restore_pipeline): + self._restore_pipeline() + else: + # this will erase the existing working folder + self._create_pipeline() + + # create schema storage + self._schema_storage = LiveSchemaStorage(self._schema_storage_config, makedirs=True) + + def _create_pipeline(self) -> None: + # kill everything inside the working folder + if self._pipeline_storage.has_folder(""): + self._pipeline_storage.delete_folder("", recursively=True) + self._pipeline_storage.create_folder("", exists_ok=False) + + def _restore_pipeline(self) -> None: + pass + + def _extract_source(self, source: DltSource, max_parallel_items: int, workers: int) -> None: + # discover the schema from source + source_schema = source.discover_schema() + + # iterate over all items in the pipeline and update the schema if dynamic table hints were present + storage = ExtractorStorage(self._normalize_storage_config) + for _, partials in extract(source, storage, max_parallel_items=max_parallel_items, workers=workers).items(): + for partial in partials: + source_schema.update_schema(source_schema.normalize_table_identifiers(partial)) + + # if source schema does not exist in the pipeline + if source_schema.name not in self._schema_storage: + # possibly initialize the import schema if it is a new schema + self._schema_storage.initialize_import_if_new(source_schema) + # save schema into the pipeline + self._schema_storage.save_schema(source_schema) + # and set as default if this is first schema in pipeline + if not self.default_schema_name: + self.default_schema_name = source_schema.name + + + def _run_step_in_pool(self, step: TPipelineStep, runnable: Runnable[Any], config: PoolRunnerConfiguration) -> int: + try: + ec = runner.run_pool(config, runnable) + # in any other case we raise if runner exited with status failed + if runner.LAST_RUN_METRICS.has_failed: + raise PipelineStepFailed(step, self.last_run_exception, runner.LAST_RUN_METRICS) + return ec + except Exception as r_ex: + # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly + raise PipelineStepFailed(step, self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex + finally: + signals.raise_if_signalled() + + def _run_f_in_pool(self, run_f: Callable[..., Any], config: PoolRunnerConfiguration) -> int: def _run(_: Any) -> TRunMetrics: rv = run_f() @@ -303,7 +484,7 @@ def _run(_: Any) -> TRunMetrics: return TRunMetrics(False, False, int(pending)) # run the fun - ec = runner.run_pool(self.C, _run) + ec = runner.run_pool(config, _run) # ec > 0 - signalled # -1 - runner was not able to start @@ -311,116 +492,83 @@ def _run(_: Any) -> TRunMetrics: raise self.last_run_exception return ec + def _get_destination_client_initial_config(self, credentials: Any = None) -> DestinationClientConfiguration: + if not self.destination: + raise PipelineConfigMissing( + "destination", + "load", + "Please provide `destination` argument to `config` or `load` method or via pipeline config file or environment var." 
+ ) + dataset_name = self._get_dataset_name() + # create initial destination client config + client_spec = self.destination.spec() + if issubclass(client_spec, DestinationClientDwhConfiguration): + # client support schemas and datasets + return client_spec(dataset_name=dataset_name, default_schema_name=self.default_schema_name, credentials=credentials) + else: + return client_spec(credentials=credentials) - def _configure_normalize(self) -> None: - # create normalize config - normalize_initial = { - "NORMALIZE_VOLUME_PATH": os.path.join(self.root_path, "normalize"), - "SCHEMA_VOLUME_PATH": os.path.join(self.root_path, "schemas"), - "EXPORT_SCHEMA_PATH": os.path.abspath(self.export_schema_path) if self.export_schema_path else None, - "IMPORT_SCHEMA_PATH": os.path.abspath(self.import_schema_path) if self.import_schema_path else None, - "LOADER_FILE_FORMAT": self._loader_instance.load_client_cls.capabilities()["preferred_loader_file_format"], - "ADD_EVENT_JSON": False - } - normalize_initial.update(self._configure_runner()) - C = normalize_configuration(initial_values=normalize_initial) - # shares schema storage with the pipeline so we do not need to install - self._normalize_instance = Normalize(C) - - def _configure_load(self) -> None: - # use credentials to populate loader client config, it includes also client type - loader_client_initial = dtc_asdict(self.credentials) - loader_client_initial["DEFAULT_SCHEMA_NAME"] = self.default_schema_name - # but client type must be passed to loader config - loader_initial = {"CLIENT_TYPE": loader_client_initial["CLIENT_TYPE"]} - loader_initial.update(self._configure_runner()) - loader_initial["DELETE_COMPLETED_JOBS"] = True + def _get_destination_client(self, schema: Schema, initial_config: DestinationClientConfiguration = None) -> JobClientBase: try: - C = loader_configuration(initial_values=loader_initial) - self._loader_instance = Load(C, REGISTRY, client_initial_values=loader_client_initial, is_storage_owner=True) + # config is not provided then get it with injected credentials + if not initial_config: + initial_config = self._get_destination_client_initial_config() + return self.destination.client(schema, initial_config) except ImportError: + client_spec = self.destination.spec() raise MissingDependencyException( - f"{self.credentials.CLIENT_TYPE} destination", - [f"python-dlt[{self.credentials.CLIENT_TYPE}]"], - "Dependencies for specific destination are available as extras of python-dlt" + f"{client_spec.destination_name} destination", + [f"python-dlt[{client_spec.destination_name}]"], + "Dependencies for specific destinations are available as extras of python-dlt" ) - def _verify_loader_instance(self) -> None: - if self._loader_instance is None: - raise NoPipelineException() - - def _verify_normalize_instance(self) -> None: - if self._loader_instance is None: - raise NoPipelineException() - - def _configure_runner(self) -> StrAny: - return { - "PIPELINE_NAME": self.pipeline_name, - "IS_SINGLE_RUN": True, - "WAIT_RUNS": 0, - "EXIT_ON_EXCEPTION": True, - "LOAD_VOLUME_PATH": os.path.join(self.root_path, "normalized") - } - - def _load_modules(self) -> None: - # configure loader - self._configure_load() - # configure normalize - self._configure_normalize() - - def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny]) -> None: - try: - for idx, i in enumerate(items): - if not isinstance(i, dict): - # TODO: convert non dict types into dict - items[idx] = i = {"v": i} - if DLT_METADATA_FIELD not in i or 
i.get(DLT_METADATA_FIELD, None) is None: - # set default table name - with_table_name(i, default_table_name) - - load_id = uniq_id() - self.extractor_storage.save_json(f"{load_id}.json", items) - self.extractor_storage.commit_events( - self.default_schema_name, - self.extractor_storage.storage._make_path(f"{load_id}.json"), - default_table_name, - len(items), - load_id - ) + def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: + if not self.destination: + raise PipelineConfigMissing( + "destination", + "normalize", + "Please provide `destination` argument to `config` or `load` method or via pipeline config file or environment var." + ) + return self.destination.capabilities() - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=False, pending_items=0) - except Exception as ex: - logger.exception("extracting iterator failed") - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=True, pending_items=0) - runner.LAST_RUN_EXCEPTION = ex - raise + def _get_dataset_name(self) -> str: + return self.dataset_name or self.pipeline_name + + def _get_state(self) -> TPipelineState: + try: + state: TPipelineState = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) + except FileNotFoundError: + state = {} + return state @contextmanager - def _managed_state(self) -> Iterator[None]: - backup_state = deepcopy(self.state) + def _managed_state(self) -> Iterator[TPipelineState]: + # load current state + state = self._get_state() + # write props to pipeline variables + for prop in Pipeline.STATE_PROPS: + setattr(self, prop, state.get(prop)) + if "destination" in state: + self.destination = DestinationReference.from_name(self.destination) + try: - yield + yield state except Exception: # restore old state - self.state.clear() - self.state.update(backup_state) + # currently do nothing - state is not preserved in memory, only saved raise else: + # update state props + for prop in Pipeline.STATE_PROPS: + state[prop] = getattr(self, prop) # type: ignore + if self.destination: + state["destination"] = self.destination.__name__ + + # load state from storage to be merged with pipeline changes, currently we assume no parallel changes + # compare backup and new state, save only if different + backup_state = self._get_state() + new_state = json.dumps(state, sort_keys=True) + old_state = json.dumps(backup_state, sort_keys=True) # persist old state - self.root_storage.save("state.json", json.dumps(self.state)) - - def _restore_state(self) -> None: - self.state.clear() - restored_state: DictStrAny = json.loads(self.root_storage.load("state.json")) - self.state.update(restored_state) - - @staticmethod - def save_schema_to_file(file_name: str, schema: Schema, remove_defaults: bool = True) -> None: - with open(file_name, "w", encoding="utf-8") as f: - f.write(schema.to_pretty_yaml(remove_defaults=remove_defaults)) - - @staticmethod - def load_schema_from_file(file_name: str) -> Schema: - with open(file_name, "r", encoding="utf-8") as f: - schema_dict: DictStrAny = yaml.safe_load(f) - return Schema.from_dict(schema_dict) + if new_state != old_state: + self._pipeline_storage.save(Pipeline.STATE_FILE, new_state) diff --git a/dlt/pipeline/typing.py b/dlt/pipeline/typing.py index f269a0c17e..8a1248e93b 100644 --- a/dlt/pipeline/typing.py +++ b/dlt/pipeline/typing.py @@ -1,98 +1,16 @@ +from typing import Literal, TypedDict, Optional -from typing import Literal, Type, Any -from dataclasses import dataclass, fields as dtc_fields -from dlt.common import json -from 
dlt.common.typing import StrAny, TSecretValue +TPipelineStep = Literal["extract", "normalize", "load"] -TLoaderType = Literal["bigquery", "redshift", "dummy"] -TPipelineStage = Literal["extract", "normalize", "load"] +class TPipelineState(TypedDict, total=False): + pipeline_name: str + dataset_name: str + default_schema_name: Optional[str] + destination: Optional[str] -# extractor generator yields functions that returns list of items of the type (table) when called -# this allows generator to implement retry logic -# TExtractorItem = Callable[[], Iterator[StrAny]] -# # extractor generator yields tuples: (type of the item (table name), function defined above) -# TExtractorItemWithTable = Tuple[str, TExtractorItem] -# TExtractorGenerator = Callable[[DictStrAny], Iterator[TExtractorItemWithTable]] +# TSourceState = NewType("TSourceState", DictStrAny) -@dataclass -class PipelineCredentials: - CLIENT_TYPE: TLoaderType - - @property - def default_dataset(self) -> str: - pass - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - pass - -@dataclass -class GCPPipelineCredentials(PipelineCredentials): - PROJECT_ID: str = None - DEFAULT_DATASET: str = None - CLIENT_EMAIL: str = None - PRIVATE_KEY: TSecretValue = None - LOCATION: str = "US" - CRED_TYPE: str = "service_account" - TOKEN_URI: str = "https://oauth2.googleapis.com/token" - HTTP_TIMEOUT: float = 15.0 - RETRY_DEADLINE: float = 600 - - @property - def default_dataset(self) -> str: - return self.DEFAULT_DATASET - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - self.DEFAULT_DATASET = new_value - - @classmethod - def from_services_dict(cls, services: StrAny, dataset_prefix: str, location: str = "US") -> "GCPPipelineCredentials": - assert dataset_prefix is not None - return cls("bigquery", services["project_id"], dataset_prefix, services["client_email"], services["private_key"], location or cls.LOCATION) - - @classmethod - def from_services_file(cls, services_path: str, dataset_prefix: str, location: str = "US") -> "GCPPipelineCredentials": - with open(services_path, "r", encoding="utf-8") as f: - services = json.load(f) - return GCPPipelineCredentials.from_services_dict(services, dataset_prefix, location) - - @classmethod - def default_credentials(cls, dataset_prefix: str, project_id: str = None, location: str = None) -> "GCPPipelineCredentials": - return cls("bigquery", project_id, dataset_prefix, None, None, location or cls.LOCATION) - - -@dataclass -class PostgresPipelineCredentials(PipelineCredentials): - DBNAME: str = None - DEFAULT_DATASET: str = None - USER: str = None - HOST: str = None - PASSWORD: TSecretValue = None - PORT: int = 5439 - CONNECT_TIMEOUT: int = 15 - - @property - def default_dataset(self) -> str: - return self.DEFAULT_DATASET - - @default_dataset.setter - def default_dataset(self, new_value: str) -> None: - self.DEFAULT_DATASET = new_value - - -def credentials_from_dict(credentials: StrAny) -> PipelineCredentials: - - def ignore_unknown_props(typ_: Type[Any], props: StrAny) -> StrAny: - fields = {f.name: f for f in dtc_fields(typ_)} - return {k:v for k,v in props.items() if k in fields} - - client_type = credentials.get("CLIENT_TYPE") - if client_type == "bigquery": - return GCPPipelineCredentials(**ignore_unknown_props(GCPPipelineCredentials, credentials)) - elif client_type == "redshift": - return PostgresPipelineCredentials(**ignore_unknown_props(PostgresPipelineCredentials, credentials)) - else: - raise ValueError(f"CLIENT_TYPE: {client_type}") +# class 
TPipelineState() +# sources: Dict[str, TSourceState] diff --git a/examples/google_drive_csv.py b/examples/google_drive_csv.py index 2b2f120e8e..36f129ea7a 100644 --- a/examples/google_drive_csv.py +++ b/examples/google_drive_csv.py @@ -62,8 +62,7 @@ def download_csv_as_json(file_id: str, csv_options: StrAny = None) -> Iterator[D # SCHEMA CREATION data_schema = None - # data_schema_file_path = f"/Users/adrian/PycharmProjects/sv/dlt/examples/schemas/inferred_drive_csv_{file_id}_schema.yml" - data_schema_file_path = f"examples/schemas/inferred_drive_csv_{file_id}_schema.yml" + data_schema_file_path = f"examples/schemas/inferred_drive_csv_{file_id}.schema.yml" credentials = GCPPipelineCredentials.from_services_file(gcp_credential_json_file_path, schema_prefix) diff --git a/examples/schemas/discord_schema.yml b/examples/schemas/discord.schema.yml similarity index 100% rename from examples/schemas/discord_schema.yml rename to examples/schemas/discord.schema.yml diff --git a/examples/schemas/hubspot_schema.yml b/examples/schemas/hubspot.schema.yml similarity index 100% rename from examples/schemas/hubspot_schema.yml rename to examples/schemas/hubspot.schema.yml diff --git a/examples/schemas/inferred_demo_schema.yml b/examples/schemas/inferred_demo.schema.yml similarity index 100% rename from examples/schemas/inferred_demo_schema.yml rename to examples/schemas/inferred_demo.schema.yml diff --git a/examples/schemas/rasa_schema.yml b/examples/schemas/rasa.schema.yml similarity index 100% rename from examples/schemas/rasa_schema.yml rename to examples/schemas/rasa.schema.yml diff --git a/examples/sources/rasa_tracker_store.py b/examples/sources/rasa_tracker_store.py index eae691fb01..45f6c468d8 100644 --- a/examples/sources/rasa_tracker_store.py +++ b/examples/sources/rasa_tracker_store.py @@ -1,5 +1,5 @@ from typing import Iterator -from dlt.common.sources import with_table_name +from dlt.common.source import with_table_name from dlt.common.typing import DictStrAny from dlt.common.time import timestamp_within diff --git a/examples/sources/singer_tap.py b/examples/sources/singer_tap.py index d23402e807..c04b94bb00 100644 --- a/examples/sources/singer_tap.py +++ b/examples/sources/singer_tap.py @@ -4,7 +4,7 @@ from dlt.common import json from dlt.common.runners.venv import Venv -from dlt.common.sources import with_table_name +from dlt.common.source import with_table_name from dlt.common.typing import DictStrAny, StrAny, StrOrBytesPath from examples.sources.stdout import get_source as get_singer_pipe diff --git a/experiments/pipeline/__init__.py b/experiments/pipeline/__init__.py deleted file mode 100644 index 68a3156325..0000000000 --- a/experiments/pipeline/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from experiments.pipeline.pipeline import Pipeline - -pipeline = Pipeline() - -# def __getattr__(name): -# if name == 'y': -# return 3 -# raise AttributeError(f"module '{__name__}' has no attribute '{name}'") - diff --git a/experiments/pipeline/async_decorator.py b/experiments/pipeline/async_decorator.py deleted file mode 100644 index b8fd4a3997..0000000000 --- a/experiments/pipeline/async_decorator.py +++ /dev/null @@ -1,585 +0,0 @@ -import asyncio -from collections import abc -from copy import deepcopy -from functools import wraps -from itertools import chain - -import inspect -import itertools -import os -import sys -from typing import Any, Coroutine, Dict, Iterator, List, NamedTuple, Sequence, cast - -from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TColumnSchema, 
TTableSchema, TTableSchemaColumns -from dlt.common.schema.utils import new_table -from dlt.common.sources import with_retry, with_table_name, get_table_name - -# from examples.sources.rasa_tracker_store - - -_meta = {} - -_i_schema: Schema = None -_i_info = None - -#abc.Iterator - - -class TableMetadataMixin: - def __init__(self, table: TTableSchema, schema: Schema = None): - self._table = table - self._schema = schema - self._table_name = table["name"] - self.__name__ = self._table_name - - @property - def table_schema(self): - # TODO: returns unified table schema by merging _schema and _table with table taking precedence - return self._table - - -class TableIterable(abc.Iterable, TableMetadataMixin): - def __init__(self, i, table, schema = None): - self._data = i - super().__init__(table, schema) - - def __iter__(self): - # TODO: this should resolve the _data like we do in the extract method: all awaitables and deferred items are resolved - # possibly in parallel. - if isinstance(self._data, abc.Iterator): - return TableIterator(self._data, self._table, self._schema) - return TableIterator(iter(self._data), self._table, self._schema) - - -class TableIterator(abc.Iterator, TableMetadataMixin): - def __init__(self, i, table, schema = None): - self.i = i - super().__init__(table, schema) - - def __next__(self): - # export metadata to global variable so it can be read by extractor - # TODO: remove this hack if possible - global _i_info - _i_info = cast(self, TableMetadataMixin) - - return next(self.i) - - def __iter__(self): - return self - - -class TableGenerator(abc.Generator, TableMetadataMixin): - def __init__(self, g, table, schema = None): - self.g = g - super().__init__(table, schema) - - def send(self, value): - return self.g.send(value) - - def throw(self, typ, val=None, tb=None): - return self.g.throw(typ, val, tb) - - -class SourceList(abc.Sequence): - def __init__(self, s, schema): - self.s: abc.Sequence = s - self.schema = schema - - # Sized - def __len__(self) -> int: - return self.s.__len__() - - # Iterator - def __next__(self): - return next(self.s) - - def __iter__(self): - return self - - # Container - def __contains__(self, value: object) -> bool: - return self.s.__contains__(value) - - # Reversible - def __reversed__(self): - return self.s.__reversed__() - - # Sequence - def __getitem__(self, index): - return self.s.__getitem__(index) - - def index(self, value: Any, start: int = ..., stop: int = ...) 
-> int: - return self.s.index(value, start, stop) - - def count(self, value: Any) -> int: - return self.s.count(value) - -class SourceTable(NamedTuple): - table_name: str - data: Iterator[Any] - - -def source(schema=None): - """This is source decorator""" - def _dec(f: callable): - print(f"calling source on {f.__name__}") - global _i_schema - - __dlt_schema = Schema(f.__name__) if not schema else schema - sig = inspect.signature(f) - - @wraps(f) - def _wrap(*args, **kwargs): - global _i_schema - - inner_schema: Schema = None - # if "schema" in kwargs and isinstance(kwargs["schema"], Schema): - # inner_schema = kwargs["schema"] - # # remove if not in sig - # if "schema" not in sig.parameters: - # del kwargs["schema"] - - _i_schema = inner_schema or __dlt_schema - rv = f(*args, **kwargs) - - if not isinstance(rv, (abc.Iterator, abc.Iterable)) or isinstance(rv, (dict, str)): - raise ValueError(f"Expected iterator/iterable containing tables {type(rv)}") - - # assume that source contain iterator of TableIterable - tables = [] - for table in rv: - # if not isinstance(rv, abc.Iterator) or isinstance(rv, (dict, str): - if not isinstance(table, TableIterable): - raise ValueError(f"Please use @table or as_table: {type(table)}") - tables.append(table) - # iterator consumed - clear schema - _i_schema = None - # if hasattr(rv, "__name__"): - # s.a - # source with single table - # return SourceList([rv], _i_schema) - # elif isinstance(rv, abc.Sequence): - # # peek what is inside - # item = None if len(rv) == 0 else rv[1] - # # if this is list, iterator or empty - # if isinstance(item, (NoneType, TableMetadataMixin, abc.Iterator)): - # return SourceList(rv, _i_schema) - # else: - # return SourceList([rv], _i_schema) - # else: - # raise ValueError(f"Unsupported table type {type(rv)}") - - return tables - # if isinstance(rv, abc.Iterable) or inspect(rv, abc.Iterator): - # yield from rv - # else: - # yield rv - print(f.__doc__) - _wrap.__doc__ = f.__doc__ + """This is source decorator""" - return _wrap - - # if isinstance(obj, callable): - # return _wrap - # else: - # return obj - return _dec - - -def table(name = None, write_disposition = None, parent = None, columns: Sequence[TColumnSchema] = None, schema = None): - def _dec(f: callable): - global _i_schema - - if _i_schema and schema: - raise Exception("Do not set explicit schema for a table in source context") - - l_schema = schema or _i_schema - table = new_table(name or f.__name__, parent, write_disposition, columns) - print(f"calling TABLE on {f.__name__}: {l_schema}") - - # @wraps(f, updated=('__dict__','__doc__')) - def _wrap(*args, **kwargs): - rv = f(*args, **kwargs) - return TableIterable(rv, table, l_schema) - # assert _i_info == None - - # def _yield_inner() : - # global _i_info - # print(f"TABLE: setting _i_info on {f.__name__} {l_schema}") - # _i_info = (table, l_schema) - - # if isinstance(rv, abc.Sequence): - # yield rv - # # return TableIterator(iter(rv), _i_info) - # elif isinstance(rv, abc.Generator): - # # return TableGenerator(rv, _i_info) - # yield from rv - # else: - # yield from rv - # _i_info = None - # # must clean up in extract - # # assert _i_info == None - - # gen_inner = _yield_inner() - # # generator name is a table name - # gen_inner.__name__ = "__dlt_meta:" + "*" if callable(table["name"]) else table["name"] - # # return generator - # return gen_inner - - # _i_info = None - # yield from map(lambda i: with_table_name(i, id(rv)), rv) - - return _wrap - - return _dec - - -def as_table(obj, name, write_disposition = None, 
parent = None, columns = None): - global _i_schema - l_schema = _i_schema - - # for i, f in sys._current_frames(): - # print(i, f) - - # print(sys._current_frames()) - - # try: - # for d in range(0, 10): - # c_f = sys._getframe(d) - # print(c_f.f_code.co_varnames) - # print("------------") - # if "__dlt_schema" in c_f.f_locals: - # l_schema = c_f.f_locals["__dlt_schema"] - # break - # except ValueError: - # # stack too deep - # pass - - # def inner(): - # # global _i_info - - # # assert _i_info == None - # print(f"AS_TABLE: setting _i_info on {name} {l_schema}") - # table = new_table(name, parent, write_disposition, columns) - # _i_info = (table, l_schema) - # if isinstance(obj, abc.Sequence): - # return TableIterator(iter(obj), _i_info) - # elif isinstance(obj, abc.Generator): - # return TableGenerator(obj, _i_info) - # else: - # return TableIterator(obj, _i_info) - # # if isinstance(obj, abc.Sequence): - # # yield obj - # # else: - # # yield from obj - # # _i_info = None - - table = new_table(name, parent, write_disposition, columns) - print(f"calling AS TABLE on {name}: {l_schema}") - return TableIterable(obj, table, l_schema) - # def _yield_inner() : - # global _i_info - # print(f"AS_TABLE: setting _i_info on {name} {l_schema}") - # _i_info = (table, l_schema) - - # if isinstance(obj, abc.Sequence): - # yield obj - # # return TableIterator(iter(obj), _i_info) - # elif isinstance(obj, abc.Generator): - # # return TableGenerator(obj, _i_info) - # yield from obj - # else: - # yield from obj - - # # must clean up in extract - # # assert _i_info == None - - # gen_inner = _yield_inner() - # # generator name is a table name - # gen_inner.__name__ = "__dlt_meta:" + "*" if callable(table["name"]) else table["name"] - # # return generator - # return gen_inner - - # return inner() - -# def async_table(write_disposition = None, parent = None, columns = None): - -# def _dec(f: callable): - -# def _wrap(*args, **kwargs): -# global _i_info - -# l_info = new_table(f.__name__, parent, write_disposition, columns) -# rv = f(*args, **kwargs) - -# for i in rv: -# # assert _i_info == None -# # print("set info") -# _i_info = l_info -# # print(f"what: {i}") -# yield i -# _i_info = None -# # print("reset info") - -# # else: -# # yield from rv -# # yield from map(lambda i: with_table_name(i, id(rv)), rv) - -# return _wrap - -# return _dec - - -# takes name from decorated function -@source() -def spotify(api_key=None): - """This is spotify source with several tables""" - - # takes name from decorated function - @table(write_disposition="append") - def numbers(): - return [1, 2, 3, 4] - - @table(write_disposition="replace") - def songs(library): - - # https://github.com/leshchenko1979/reretry - async def _i(id): - await asyncio.sleep(0.5) - # raise Exception("song cannot be taken") - return {f"song{id}": library} - - for i in range(3): - yield _i(i) - - @table(write_disposition="replace") - def albums(library): - - async def _i(id): - await asyncio.sleep(0.5) - return {f"album_{id}": library} - - - for i in ["X", "Y"]: - yield _i(i) - - @table(write_disposition="append") - def history(): - """This is your song history""" - print("HISTORY yield") - yield {"name": "dupa"} - - print("spotify returns list") - return ( - history(), - numbers(), - songs("lib_1"), - as_table(["lib_song"], name="library"), - albums("lib_2") - ) - - -@source() -def annotations(): - """Ad hoc annotation source""" - yield as_table(["ann1", "ann2", "ann3"], "annotate", write_disposition="replace") - - -# this table exists out of source 
context and will attach itself to the current default schema in the pipeline -@table(write_disposition="merge", parent="songs_authors") -def songs__copies(song, num): - return [{"song": song, "copy": True}] * num - - -event_column_template: List[TColumnSchema] = [{ - "name": "timestamp", - "data_type": "timestamp", - "nullable": False, - "partition": True, - "sorted": True - } -] - -# this demonstrates the content based naming of the tables for stream based sources -# same rule that applies to `name` could apply to `write_disposition` and `columns` -@table(name=lambda i: "event_" + i["event"], write_disposition="append", columns=event_column_template) -def events(): - from examples.sources.jsonl import get_source as read_jsonl - - sources = [ - read_jsonl(file) for file in os.scandir("examples/data/rasa_trackers") - ] - for i in chain(*sources): - yield { - "sender_id": i["sender_id"], - "timestamp": i["timestamp"], - "event": i["event"] - } - yield i - - -# another standalone source -authors = as_table(["authr"], "songs_authors") - - -# def source_with_schema_discovery(credentials, sheet_id, tab_id): - -# # discover the schema from actual API -# schema: Schema = get_schema_from_sheets(credentials, sheet_id) - -# # now declare the source -# @source(schema=schema) -# @table(name=schema.schema_name, write_disposition="replace") -# def sheet(): -# from examples.sources.google_sheets import get_source - -# yield from get_source(credentials, sheet_id, tab_id) - -# return sheet() - - -class Pipeline: - def __init__(self, parallelism = 2, default_schema: Schema = None) -> None: - self.sem = asyncio.Semaphore(parallelism) - self.schemas: Dict[str, Schema] = {} - self.default_schema_name: str = "" - if default_schema: - self.default_schema_name = default_schema.name - self.schemas[default_schema.name] = default_schema - - async def extract(self, items, schema: Schema = None): - # global _i_info - - l_info = None - if isinstance(items, TableIterable): - l_info = (items._table, items._schema) - print(f"extracting table with name {getattr(items, '__name__', None)} {l_info}") - - # if id(i) in meta: - # print(meta[id(i)]) - - def _p_i(item, what): - if l_info: - info_schema: Schema = l_info[1] - if info_schema: - # if already in pipeline - use the pipeline one - info_schema = self.schemas.get(info_schema.name) or info_schema - # if explicit - use explicit - eff_schema = schema or info_schema - if eff_schema is None: - # create default schema when needed - eff_schema = self.schemas.get(self.default_schema_name) or Schema(self.default_schema_name) - if eff_schema is not None: - table: TTableSchema = l_info[0] - # normalize table name - if callable(table["name"]): - table_name = table["name"](item) - else: - table_name = eff_schema.normalize_table_name(table["name"]) - - if table_name not in eff_schema._schema_tables: - table = deepcopy(table) - table["name"] = table_name - # TODO: normalize all other names - eff_schema.update_schema(table) - # TODO: l_info may contain type hints etc. 
- self.schemas[eff_schema.name] = eff_schema - if len(self.schemas) == 1: - self.default_schema_name = eff_schema.name - - print(f"{item} of {what} has HINT and will be written as {eff_schema.name}:{table_name}") - else: - eff_schema = self.schemas.get(self.default_schema_name) or Schema(self.default_schema_name) - print(f"{item} of {what} No HINT and will be written as {eff_schema.name}:table") - - # l_info = _i_info - if isinstance(items, TableIterable): - items = iter(items._data) - if isinstance(items, (abc.Sequence)): - items = iter(items) - # for item in items: - # _p_i(item, "list_item") - if inspect.isasyncgen(items): - raise NotImplemented() - else: - # context is set immediately - item = next(items, None) - if item is None: - return - - global _i_info - - if l_info is None and isinstance(_i_info, TableMetadataMixin): - l_info = (_i_info._table, _i_info._schema) - # l_info = _i_info - # _i_info = None - if inspect.iscoroutine(item) or isinstance(item, Coroutine): - async def _await(a_i): - async with self.sem: - # print("enter await") - a_itm = await a_i - _p_i(a_itm, "awaitable") - - items = await asyncio.gather( - asyncio.ensure_future(_await(item)), *(asyncio.ensure_future(_await(ii)) for ii in items) - ) - else: - _p_i(item, "iterator") - list(map(_p_i, items, itertools.repeat("iterator"))) - - # print("reset info") - # _i_info = None - # assert _i_info is None - - def extract_all(self, sources, schema: Schema = None): - loop = asyncio.get_event_loop() - loop.run_until_complete(asyncio.gather(*[self.extract(i, schema=schema) for i in sources])) - # loop.close() - - -default_schema = Schema("") - -print("s iter of iters") -s = spotify(api_key="7") -for items in s: - print(items.__name__) -s[0] = map(lambda d: {**d, **{"lambda": True}} , s[0]) -print("s2 iter of iters") -s2 = annotations() - -# for x in s[0]: -# print(x) -# exit() - -# print(list(s2)) -# exit(0) - -# mod albums - -def mapper(a): - a["mapper"] = True - return a - -# https://asyncstdlib.readthedocs.io/en/latest/# -# s[3] = map(mapper, s[3]) - - -# Pipeline().extract_all([s2]) -p = Pipeline(default_schema=Schema("default")) -chained = chain(s, s2, [authors], [songs__copies("abba", 4)], [["raw", "raw"]]) -# for items in chained: -# print(f"{type(items)}: {getattr(items, '__name__', 'NONE')}") -p.extract_all(chained, schema=None) -# p.extract_all([events()], schema=Schema("events")) -p.extract_all([["nein"] * 5]) -p.extract_all([as_table([], name="EMPTY")]) - - -for schema in p.schemas.values(): - print(schema.to_pretty_yaml(remove_defaults=True)) - -print(p.default_schema_name) -# for i in s: -# await extract(i) -# for i in s: -# extract(chain(*i, iter([1, "zeta"]))) \ No newline at end of file diff --git a/experiments/pipeline/configuration.py b/experiments/pipeline/configuration.py deleted file mode 100644 index f584b0e159..0000000000 --- a/experiments/pipeline/configuration.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any, Type - -from dlt.common.typing import DictStrAny, TAny -from dlt.common.configuration.utils import make_configuration - - -def get_config(spec: Type[TAny], key: str = None, namespace: str = None, initial_values: Any = None, accept_partial: bool = False) -> Type[TAny]: - # TODO: implement key and namespace - return make_configuration(spec, spec, initial_values=initial_values, accept_partial=accept_partial) - diff --git a/experiments/pipeline/credentials.py b/experiments/pipeline/credentials.py deleted file mode 100644 index 1513e06f42..0000000000 --- 
a/experiments/pipeline/credentials.py +++ /dev/null @@ -1,22 +0,0 @@ - -from typing import Any, Sequence, Type - -# gets credentials in namespace (ie pipeline name), grouped under key with spec -# spec can be a class, TypedDict or dataclass. overwrites initial_values -def get_credentials(spec: Type[Any] = None, key: str = None, namespace: str = None, initial_values: Any = None) -> Any: - # will use registered credential providers for all values in spec or return all values under key - pass - - -def get_config(spec: Type[Any], key: str = None, namespace: str = None, initial_values: Any = None) -> Any: - # uses config providers (env, .dlt/config.toml) - # in case of TSecretValues fallbacks to using credential providers - pass - - -class ConfigProvider: - def get(name: str) -> Any: - pass - - def list(prefix: str = None) -> Sequence[str]: - pass diff --git a/experiments/pipeline/examples/README.md b/experiments/pipeline/examples/README.md new file mode 100644 index 0000000000..9588996df6 --- /dev/null +++ b/experiments/pipeline/examples/README.md @@ -0,0 +1,11 @@ +## Finished documents + +1. [general_usage.md](general_usage.md) +2. [project_structure.md](project_structure.md) & `dlt init` CLI +3. [create_pipeline.md](create_pipeline.md) +4. [secrets_and_config.md](secrets_and_config.md) +5. [working_with_schemas.md](working_with_schemas.md) + +## In progress + +I'll be writing advanced stuff later (ie. state and multi step pipelines etc.) diff --git a/experiments/pipeline/examples/create_pipeline.md b/experiments/pipeline/examples/create_pipeline.md new file mode 100644 index 0000000000..61af3b6e84 --- /dev/null +++ b/experiments/pipeline/examples/create_pipeline.md @@ -0,0 +1,339 @@ + +## Example for the simplest ad hoc pipeline without any structure +It is still possible to create "intuitive" pipeline without much knowledge except how to import dlt engine and how to import the destination. + +No decorators and secret files, configurations are necessary. We should probably not teach that but I want this kind of super basic and brute code to still work + + +```python +import requests +import dlt +from dlt.destinations import bigquery + +resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=1", + headers={"Authorization": "98217398ahskj92982173"}) +resp.raise_for_status() +data = resp.json() + +# if destination or name are not provided, an exception will raise that explains +# 1. where and why to put the name of the table +# 2. how to import the destination and how to configure it with credentials in a proper way +# nevertheless the user decided to pass credentials directly +dlt.run(data["result"], name="logs", destination=bigquery(Service.credentials_from_file("service.json"))) +``` + +## Source extractor function the preferred way +General guidelines: +1. the source extractor is a function decorated with `@dlt.source`. that function yields or returns a list of resources. **it should not access the data itself**. see the example below +2. resources are generator functions that always **yield** data (I think I will enforce that by raising exception). Access to external endpoints, databases etc. should happen from that generator function. Generator functions may be decorated with `@dlt.resource` to provide alternative names, write disposition etc. +3. resource generator functions can be OFC parametrized and resources may be created dynamically +4. the resource generator function may yield a single dict or list of dicts +5. 
like any other iterator, the @dlt.source and @dlt.resource **can be iterated and thus extracted and loaded only once**, see example below. + +> my dilemma here is if I should allow to access data directly in the source function ie. to discover schema or get some configuration for the resources from some endpoint. it is very easy to avoid that but for the non-programmers it will not be intuitive. + +## Example for endpoint returning only one resource: + +```python +import requests +import dlt + +# the `dlt.source` tell the library that the decorated function is a source +# it will use function name `taktile_data` to name the source and the generated schema by default +# in general `@source` should **return** a list of resources or list of generators (function that yield data) +# @source may also **yield** resources or generators - if yielding is more convenient +# if @source returns or yields data - this will generate exception with a proper explanation. dlt user can always load the data directly without any decorators like in the previous example! +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + + # the `dlt.resource` tells the `dlt.source` that the function defines a resource + # will use function name `logs` as resource/table name by default + # the function should **yield** the data items one by one or **yield** a list. + # here the decorator is optional: there are no parameters to `dlt.resource` + @dlt.resource + def logs(): + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + # option 1: yield the whole list + yield resp.json()["result"] + # or -> this is useful if you deal with a stream of data and for that you need an API that supports that, for example you could yield lists containing paginated results + for item in resp.json()["result"]: + yield item + + # as mentioned we return a resource or a list of resources + return logs + # this will also work + return logs() + + +# now load the data +taktile_data(1).run(destination=bigquery) +# this below also works +# dlt.run(source=taktile_data(1), destination=bigquery) + +# now to illustrate that each source can be loaded only once, if you run this below +data = taktile_data(1) +data.run(destination=bigquery) # works as expected +data.run(destination=bigquery) # generates empty load package as the data in the iterator is exhausted... maybe I should raise exception instead? + +``` + +**Remarks:** + +1. the **@dlt.resource** let's you define the table schema hints: `name`, `write_disposition`, `parent`, `columns` +2. the **@dlt.source** let's you define global schema props: `name` (which is also source name), `schema` which is Schema object if explicit schema is provided `nesting` to set nesting level etc. (I do not have a signature now - I'm still working on it) +3. decorators can also be used as functions ie in case of dlt.resource and `lazy_function` (see one page below) +```python +endpoints = ["songs", "playlist", "albums"] +# return list of resourced +return [dlt.resource(lazy_function(endpoint, name=endpoint) for endpoint in endpoints)] + +``` + +**What if we remove logs() function and get data in source body** + +Yeah definitely possible. 
Just replace `@source` with `@resource` decorator and remove the function + +```python +@dlt.resource(name="logs", write_disposition="append") +def taktile_data(initial_log_id, taktile_api_key): + + # yes, this will also work but data will be obtained immediately when taktile_data() is called. + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + for item in resp.json()["result"]: + yield item + +# this will load the resource into default schema. see `general_usage.md) +dlt.run(source=taktile_data(1), destination=bigquery) + +``` + +**The power of decorators** + +With decorators dlt can inspect and modify the code being decorated. +1. it knows what are the sources and resources without running them +2. it knows input arguments so it knows the config values and secret values (see `secrets_and_config`). with those we can generate deployments automatically +3. it can inject config and secret values automatically +4. it wraps the functions into objects that provide additional functionalities +- sources and resources are iterators so you can write +```python +items = list(source()) + +for item in source()["logs"]: + ... +``` +- you can select which resources to load with `source().select(*names)` +- you can add mappings and filters to resources + +## The power of yielding: The preferred way to write resources + +The Python function that yields is not a function but magical object that `dlt` can control: +1. it is not executed when you call it! the call just creates a generator (see below). in the example above `taktile_data(1)` will not execute the code inside, it will just return an object composed of function code and input parameters. dlt has control over the object and can execute the code later. this is called `lazy execution` +2. i can control when and how much of the code is executed. the function that yields typically looks like that + +```python +def lazy_function(endpoint_name): + # INIT - this will be executed only once when DLT wants! + get_configuration() + from_item = dlt.state.get("last_item", 0) + l = get_item_list_from_api(api_key, endpoint_name) + + # ITERATOR - this will be executed many times also when DLT wants more data! + for item in l: + yield requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json() + # CLEANUP + # this will be executed only once after the last item was yielded! + dlt.state["last_item"] = item["id"] +``` + +3. dlt will execute this generator in extractor. the whole execution is atomic (including writing to state). if anything fails with exception the whole extract function fails. +4. the execution can be parallelized by using a decorator or a simple modifier function ie: +```python +for item in l: + yield deferred(requests.get(url, api_key, "%s?id=%s" % (endpoint_name, item["id"])).json()) +``` + +## Python data transformations + +```python +from dlt.secrets import anonymize + +def transform_user(user_data): + # anonymize creates nice deterministic hash for any hashable data type + user_data["user_id"] = anonymize(user_data["user_id"]) + user_data["user_email"] = anonymize(user_data["user_email"]) + return user_data + +# usage: can be applied in the source +@dlt.source +def hubspot(...): + ... + + @dlt.resource(write_disposition="replace") + def users(): + ... + users = requests.get(...) + # option 1: just map and yield from mapping + users = map(transform_user, users) + ... 
+        yield users, deals, customers
+
+        # option 2: return resource with chained transformation
+        return users.map(transform_user)
+
+# option 3: the user of the pipeline decides whether they want the anonymized data and applies the transformation in the pipeline script. so the source may also offer transformations that are easy to apply
+hubspot(...)["users"].map(transform_user)
+hubspot.run(...)
+
+```
+
+## Multiple resources and resource selection
+The source extraction function may contain multiple resources. The resources can be defined as multiple resource functions or created dynamically ie. with parametrized generators.
+The user of the pipeline can check which resources are available and select the resources to load.
+
+
+**each resource has a separate resource function**
+```python
+import requests
+import dlt
+
+@dlt.source
+def hubspot(...):
+
+    @dlt.resource(write_disposition="replace")
+    def users():
+        # calls to the API happen here
+        ...
+        yield users
+
+    @dlt.resource(write_disposition="append")
+    def transactions():
+        ...
+        yield transactions
+
+    # return a list of resources
+    return users, transactions
+
+# load all resources
+hubspot(...).run(destination=bigquery)
+# load only users
+hubspot(...).select("users").run(....)
+```
+
+**resources are created dynamically**
+Here we implement a single parametrized function that **yields** data and we call it repeatedly. Mind that the function body won't be executed immediately, but only later when the generator is consumed in the extract stage.
+
+```python
+
+@dlt.source
+def spotify():
+
+    endpoints = ["songs", "playlists", "albums"]
+
+    def get_resource(endpoint):
+        # here we yield the whole response
+        yield requests.get(url + "/" + endpoint).json()
+
+    # here we yield resources because this produces cleaner code
+    for endpoint in endpoints:
+        # calling get_resource creates a generator, the actual code of the function will be executed in the extractor
+        yield dlt.resource(get_resource(endpoint), name=endpoint)
+
+```
+
+**resources are created dynamically from a single document**
+Here we have a list of huge documents and we want to split it into several tables. We do not want to rely on the `dlt` normalize stage to do it for us for some reason...
+
+This also provides an example of why getting data in the source function (and not within the resource function) is discouraged.
+
+```python
+
+@dlt.source
+def spotify():
+
+    # get the data in the source body and then simply return the resources
+    # this is discouraged because the data access happens here, outside the pipeline (see below)
+    list_of_huge_docs = requests.get(...)
+
+    return dlt.resource(list_of_huge_docs["songs"], name="songs"), dlt.resource(list_of_huge_docs["playlists"], name="playlists")
+
+# the call to get the resource list happens outside the `dlt` pipeline, this means that if there's an
+# exception in `list_of_huge_docs = requests.get(...)` I cannot handle or log it (or send a slack message)
+# the user must do it themselves or the script will simply be killed. not so much of a problem during development
+# but may be a problem after deployment.
+spotify().run(...)
+```
+
+How to prevent that:
+```python
+@dlt.source
+def spotify():
+
+    list_of_huge_docs = None
+
+    def get_data(name):
+        # regarding intuitiveness and cleanliness of the code this is a hack/trickery IMO
+        # this will also have consequences if execution is parallelized
+        nonlocal list_of_huge_docs
+        if list_of_huge_docs is None:
+            list_of_huge_docs = requests.get(...)
+        docs = list_of_huge_docs
+        yield docs[name]
+
+    return dlt.resource(get_data("songs"), name="songs"), dlt.resource(get_data("playlists"), name="playlists")
+```
+
+The other way to prevent that (see also `multistep_pipelines_and_dependent_resources.md`)
+
+```python
+@dlt.source
+def spotify():
+
+    @dlt.resource
+    def get_huge_doc(name):
+        yield requests.get(...)
+
+    # make songs and playlists dependent on get_huge_doc
+    @dlt.resource(depends_on=get_huge_doc)
+    def songs(huge_doc):
+        yield huge_doc["songs"]
+
+    @dlt.resource(depends_on=get_huge_doc)
+    def playlists(huge_doc):
+        yield huge_doc["playlists"]
+
+    # as you can see get_huge_doc is not even returned, nevertheless it will be evaluated (only once)
+    # the huge doc itself will not be extracted and loaded
+    return songs, playlists
+```
+
+> I could also implement lazy evaluation of the @dlt.source function. this is a lot of trickery in the code but definitely possible. there are consequences though: if someone requests the list of resources or the initial schema in the pipeline script before the `run` method, the function body will be evaluated. It is really hard to make intuitive code work properly.
+
+## Pipeline with multiple sources or with same source called twice
+
+Here our source is parametrized or we have several sources to be extracted. This is more or less Ty's twitter case.
+
+```python
+@dlt.source
+def mongo(from_id, to_id, credentials):
+    ...
+
+@dlt.source
+def labels():
+    ...
+
+
+# option 1: at some point I may parallelize execution of sources if called this way
+dlt.run(source=[mongo(0, 100000), mongo(100001, 200000), labels()], destination=bigquery)
+
+# option 2: be explicit (this has consequences: read the `run` method in `general_usage`)
+p = dlt.pipeline(destination=bigquery)
+p.extract(mongo(0, 100000))
+p.extract(mongo(100001, 200000))
+p.extract(labels())
+p.normalize()
+p.load()
+```
diff --git a/experiments/pipeline/examples/general_usage.md b/experiments/pipeline/examples/general_usage.md
new file mode 100644
index 0000000000..c7b214375b
--- /dev/null
+++ b/experiments/pipeline/examples/general_usage.md
@@ -0,0 +1,197 @@
+## importing dlt
+Basic `dlt` functionalities are imported with `import dlt`. Those functionalities are:
+1. ability to run the pipeline (which means extract->normalize->load for particular source(s) and destination) with `dlt.run`
+2. ability to configure the pipeline ie. provide an alternative pipeline name, working directory, folders to import/export schemas and various flags: `dlt.pipeline`
+3. ability to decorate sources (`dlt.source`) and resources (`dlt.resource`)
+4. ability to access secrets `dlt.secrets` and config values `dlt.config`
+
+## importing destinations
+We support a few built-in destinations which may be imported as follows
+```python
+import dlt
+from dlt.destinations import bigquery
+from dlt.destinations import redshift
+```
+
+The imported modules may be passed directly to the `run` or `pipeline` method. They can also be called to provide credentials and other settings explicitly (discouraged) ie. `bigquery(Service.credentials_from_file("service.json"))` will bind the credentials to the module.
+
+Destinations require `extras` to be installed; if that is not the case, an exception with a user-friendly message will tell how to do that.
+
+## importing sources
+We do not have any structure for the source repository so IDK. For the `create pipeline` workflow the source is in the same script as the `run` method so the problem does not exist now (?). For illustration, a hypothetical example is sketched below.
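A source extractor defined in another module could be imported and run like any other Python function. The sketch below borrows the `my_taktile_source` module name used in the examples later in this file; the module itself is hypothetical:

```python
# hypothetical module: the @dlt.source decorated function lives in my_taktile_source.py
from my_taktile_source import taktile_data

import dlt
from dlt.destinations import bigquery

# the imported source is just a function - call it and pass the result to run
dlt.run(source=taktile_data(1), destination=bigquery)
```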
+ +In principle, however, the importable sources are extractor functions so they are imported like any other function. + + +## default and explicitly configured pipelines +When the `dlt` is imported a default pipeline is automatically created. That pipeline is configured via configuration providers (ie. `config.toml` or env variables - see [secrets_and_config.md](secrets_and_config.md)). If no configuration is present, default values will be used. + +1. the name of the pipeline, the name of default schema (if not overridden by the source extractor function) and the default dataset (in destination) are set to **current module name** which in 99% of cases is the name of executing python script +2. the working directory of the pipeline will be **OS temporary folder/pipeline name** +3. the logging level will be **INFO** +4. all other configuration options won't be set or will have default values. + +Pipeline can be explicitly created and configured via `dlt.pipeline()` that returns `Pipeline` object. All parameters are optional. If no parameter is provided then default pipeline is returned. Here's a list of options. All the options are configurable. +1. pipeline_name - default as above +2. working_dir - default as above +3. pipeline_secret - for deterministic hashing - default is random number +4. destination - the imported destination module or module name (we accept strings so they can be configured) - default is None +5. import_schema_path - default is None +6. export_schema_path - default is None +7. full_refresh - if set to True all the pipeline working dir and all datasets will be dropped with each run +8. ...any other popular option... give me ideas. maybe `dataset_name`? + +> **Achtung** as per `secrets_and_config.md` the options passed in the code have **lower priority** than any config settings. Example: the pipeline name passed to `dlt.pipeline()` will be overwritten if `pipeline_name` is present in `config.toml` or `PIPELINE_NAME` is in config variables. + + +> It is possible to have several pipelines in a single script if many pipelines are configured via `dlt.pipeline()`. I think we do not want to train people on that so I will skip the topic. + +## the default schema and the default data set name +`dlt` follows the following rules when auto-generating schemas and naming the dataset to which the data will be loaded. + +**schemas are identified by schema names** + +**default schema** is the first schema that is provided or created within the pipeline. First schema comes in the following ways: +1. From the first extracted `@dlt.source` ie. if you `dlt.run(source=sportify(), ...)` and `spotify` source has schema with name `spotify` attached, it will be used as default schema. +2. it will be created from scratch if you extract a `@dlt.resource` or an iterator ie. list (example: `dlt.run(source=["a", "b", "c"], ...)`) and its name is the pipeline name or generator function name if generator is extracted. (I'm trying to be smart with automatic naming) +3. it is explicitly passed with the `schema` parameter to `run` or `extract` methods - this forces all the sources regardless of the form to place their tables in that schema. + +The **default schema** comes into play when we extract data as in point (2) - without schema information. in that case the default schema is used to attach tables coming from that data + +The pipeline works with multiple schemas. If you extract another source or provide schema explicitly, that schema becomes part of pipeline. 
Example +```python + +p = dlt.pipeline(dataset="spotify_data_1") +p.extract(source=spotify("me")) # gets schema "spotify" from spotify source, that schema becomes default schema +p.extract(source=echonest("me").select("mel")) # get schema "echonest", all tables belonging to resource "mel" will be placed in that schema +p.extract(source=[label1, label2, label3], name="labels") # will use default schema "spotitfy" for table "labels" +``` + +> learn more on how to work with schemas both via files and programmatically in [working_with_schemas.md](working_with_schemas.md) + +**dataset name** +`dlt` will load data to a specified dataset in the destination. The dataset in case of bigquery is a native dataset, in case of redshift is a native database schema. **One dataset can handle only one schema**. + +There is a default dataset name which is the same as pipeline name. The dataset name can also be explicitly provided into `dlt.pipeline` `dlt.run` and `Pipeline::load` methods. + +In case **there's only default schema** in the pipeline, the data will be loaded into dataset name. Example: `dlt.run(source=spotify("me"), dataset="spotify_data_1")` will load data into dataset `spotify_data_1`) + +In case **there are more schemas in the pipeline**, the data will be loaded into dataset with name `{dataset_name}` for default schema and `{dataset_name}_{schema_name}` for all the other schemas. For the example above: +1. `spotify` tables and `labels` will load into `spotify_data_1` +2. `mel` resource will load into `spotify_data_1_echonest` + + +## pipeline working directory and state +the working directory of the pipeline will be **OS temporary folder/pipeline name** + +Another fundamental concept is the pipeline working directory. This directory keeps the following information: +1. the extracted data and the load packages with jobs created by normalize +2. the current schemas with all the recent updates +3. the pipeline and source state files. + +**Pipeline working directory should be preserved between the runs - if possible** + +If the working directory is not preserved: +1. the auto-evolved schema is reset to the initial one. the schema evolution is deterministic so it should not be a problem - just a time wasted to compare schemas with each run +2. if load package is not fully loaded and erased then the destination holds partially loaded and not committed `load_id` +3. the sources that need source state will not load incrementally. + +This is the situation right now. We could restore working directory from the destination (both schemas and state). Entirely doable (for some destinations) but can't be done right now. + +## running pipelines and `dlt.run` + `@source().run` functions +`dlt.run` + `@source().run` are shortcuts to `Pipeline::run` method on default or last configured (with `dlt.pipeline`) `Pipeline` object. Please refer to [create_pipeline.md](create_pipeline.md) for examples. + +The function takes the following parameters +1. source - required - the data to be loaded into destination: a `@dlt.source` or a list of those, a `@dlt.resource` or a list of those, an iterator/generator function or a list of those or iterable (ie. a list) holding something else that iterators. +2. destination +3. dataset name +4. table_name, write_disposition etc. - only when data is: a single resource, an iterator (ie. generator function) or iterable (ie. list) +5. schema - a `Schema` instance to be used instead of schema provided by the source or the default schema + +The `run` function works as follows. +1. 
if there's any pending data to be normalized or loaded, this is done first. +2. only when successful more data is extracted +3. only when successful newly extracted data is normalized and loaded. + +extract / normalize / load are atomic. the `run` is as close to be atomic as possible. + +the `run` and `load` return information on loaded packages: to which datasets, list of jobs etc. let me think what should be the content + +> `load` is atomic if SQL transformations ie in `dbt` and all the SQL queries take into account only committed `load_ids`. It is certainly possible - we did in for RASA but requires some work... Maybe we implement a fully atomic staging at some point in the loader. + + +## the `Pipeline` object +There are many ways to create or get current pipeline object. +```python + +# create and get default pipeline +p1 = dlt.pipeline() +# create explicitly configured pipeline +p2 = dlt.pipeline(name="pipe", destination=bigquery) +# get recently created pipeline +assert dlt.pipeline() is p2 +# load data with recently created pipeline +assert dlt.run(source=taktile_data()) is p2 +assert taktile_data().run() is p2 + +``` + +The `Pipeline` object provides following functionalities: +1. `run`, `extract`, `normalize` and `load` methods +2. a `pipeline.schema` dictionary-like object to enumerate and get the schemas in pipeline +3. schema get with `pipeline.schemas[name]` is a live object: any modification to it is automatically applied to the pipeline with the next `run`, `load` etc. see [working_with_schemas.md](working_with_schemas.md) +4. it returns `sql_client` and `native_client` to get direct access to the destination (if destination supports SQL - currently all of them do) +5. it has several methods to inspect the pipeline state and I think those should be exposed via `dlt pipeline` CLI + +for example: +- list the extracted files if any +- list the load packages ready to load +- list the failed jobs in package +- show info on destination: what are the datasets, the current load_id, the current schema etc. + + +## Examples + +Loads data from `taktile_data` source function into bigquery. All the credentials and configs are taken from the config and secret providers. + +Script was run with `python taktile.py` + +```python +from my_taktile_source import taktile_data +from dlt.destinations import bigquery + +# the `run` command below will create default pipeline and use it to load data +# I only want logs from the resources present in taktile_data +taktile_data.select("logs").run(destination=bigquery) + +# alternative +dlt.run(source=taktile_data.select("logs")) +``` + +Explicitly configure schema before the use +```python +import dlt +from dlt.destinations import bigquery + +@dlt.source +def data(api_key): + ... + + +dlt.pipeline(name="pipe", destination=bigquery, dataset="extract_1") +# use dlt secrets directly to get api key +# no parameters needed to run - we configured destination and dataset already +data(dlt.secrets["api_key"]).run() +``` + +## command line interface +I need concept for that. see [project_structure.md](project_structure.md) + +## logging +I need your input for user friendly logging. What should we log? What is important to see? + +## pipeline runtime setup + +1. logging - creates logger with the name `dlt` which can be disabled the python way if someone does not like it. (contrary to `dbt` logger which is uncontrollable mess) +2. signals - signals required to gracefully stop pipeline with CTRL-C, in docker, kubernetes, cron are handled. 
signals are not handled if `dlt` runs as part of `streamlit` app or a notebook. +3. unhandled exceptions - we do not catch unhandled exceptions... but we may do that if run in standalone script. \ No newline at end of file diff --git a/experiments/pipeline/examples/last_value_with_state.md b/experiments/pipeline/examples/last_value_with_state.md new file mode 100644 index 0000000000..b90d629f72 --- /dev/null +++ b/experiments/pipeline/examples/last_value_with_state.md @@ -0,0 +1,26 @@ + +## With pipeline state and incremental load + + +from_log_id = dlt.state.get("from_log_id") or initial_log_id +```python +import requests +import dlt + +# it will use function name `taktile_data` to name the source and schema +@dlt.source +def taktile_data(initial_log_id, taktile_api_key): + from_log_id = dlt.state.get("from_log_id") or initial_log_id + resp = requests.get( + "https://taktile.com/api/v2/logs?from_log_id=%i" % initial_log_id, + headers={"Authorization": taktile_api_key}) + resp.raise_for_status() + data = resp.json() + + # write state before returning data + + # yes you can return a list of values and it will work + yield dlt.resource(data["result"], name="logs") + + +taktile_data(1).run(destination=bigquery) \ No newline at end of file diff --git a/dlt/extract/generator/__init__.py b/experiments/pipeline/examples/multistep_pipelines_and_dependent_resources.md similarity index 100% rename from dlt/extract/generator/__init__.py rename to experiments/pipeline/examples/multistep_pipelines_and_dependent_resources.md diff --git a/experiments/pipeline/examples/project_structure.md b/experiments/pipeline/examples/project_structure.md new file mode 100644 index 0000000000..0dd18097f1 --- /dev/null +++ b/experiments/pipeline/examples/project_structure.md @@ -0,0 +1,35 @@ +## Project structure for a create pipeline workflow + +Look into [project_structure](project_structure). It is a clone of template repository that we should have in our github. The files in the repository are parametrized with parameters of `dlt init` command. + +1. it contains `.dlt` folded with `config.toml` and `secrets.toml`. +2. we prefill those files with values corresponding to the destination +3. the requirements contain `python-dlt` and `requests` in `requirements.txt` +4. `.gitignore` for `secrets.toml` and `.env` (python virtual environment) +5. the pipeline script file `pipeline.py` containing template for a new source +6. `README.md` file with whatever content we need + + +## dlt init + +The prerequisites to run the command is to +1. create virtual environment +2. install `python-dlt` without extras + +> Question: any better ideas? I do not see anything simpler to go around. + +Proposed interface for the command: +`dlt init ` +Where `destination` must be one of our supported destination names: `bigquery` or `redshift` and source is alphanumeric string. + +Should be executed in an empty directory without `.git` or any other files. It will clone a template and create the project structure as above. The files in the project will be customized: + +1. `secrets.toml` will be prefilled with required credentials and secret values +2. `config.toml` will contain `pipeline_name` +3. the `pipeline.py` (1) will import the right destination (2) the source name will be changed to `_data` (3) the dataset name will be changed to `` etc. +4. `requirements.txt` will contain a proper dlt extras and requests library + +> Questions: +> 1. should we generate a working pipeline as a template (ie. 
with existing API) or a piece of code with instructions how to change it? +> 2. which features should we show in the template? parametrized source? providing api key and simple authentication? many resources? parametrized resources? configure export and import of schema yaml files? etc? +> 3. should we `pip install` the required extras ans requests when `dlt init` is run? diff --git a/experiments/pipeline/examples/project_structure/.dlt/config.toml b/experiments/pipeline/examples/project_structure/.dlt/config.toml new file mode 100644 index 0000000000..fd197222ac --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.dlt/config.toml @@ -0,0 +1,2 @@ +pipeline_name="twitter" +# export_ diff --git a/experiments/pipeline/examples/project_structure/.dlt/secrets.toml b/experiments/pipeline/examples/project_structure/.dlt/secrets.toml new file mode 100644 index 0000000000..6844d188d5 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.dlt/secrets.toml @@ -0,0 +1,6 @@ +api_key="set me up" + +[destination.bigquery.credentials] +project_id="set me up" +private_key="set me up" +client_email="set me up" diff --git a/experiments/pipeline/examples/project_structure/.gitignore b/experiments/pipeline/examples/project_structure/.gitignore new file mode 100644 index 0000000000..b3b3ed2cb7 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/.gitignore @@ -0,0 +1 @@ +secrets.toml \ No newline at end of file diff --git a/experiments/pipeline/examples/project_structure/README.md b/experiments/pipeline/examples/project_structure/README.md new file mode 100644 index 0000000000..7e17869bc1 --- /dev/null +++ b/experiments/pipeline/examples/project_structure/README.md @@ -0,0 +1,3 @@ +# How to customize and deploy this pipeline? + +Maybe the training syllabus goes here? \ No newline at end of file diff --git a/dlt/extract/generator/extractor.py b/experiments/pipeline/examples/project_structure/__init__.py similarity index 100% rename from dlt/extract/generator/extractor.py rename to experiments/pipeline/examples/project_structure/__init__.py diff --git a/experiments/pipeline/examples/project_structure/pipeline.py b/experiments/pipeline/examples/project_structure/pipeline.py new file mode 100644 index 0000000000..3b2bb61c9b --- /dev/null +++ b/experiments/pipeline/examples/project_structure/pipeline.py @@ -0,0 +1,37 @@ +import requests +import dlt +from dlt.destinations import bigquery + + +# explain `dlt.source` a little here and last_id and api_key parameters +@dlt.source +def twitter_data(last_id, api_key): + # example of Bearer Authentication + # create authorization headers + headers = { + "Authorization": f"Bearer {api_key}" + } + + # explain the `dlt.resource` and the default table naming, write disposition etc. + @dlt.resource + def example_data(): + # make a call to the endpoint with request library + resp = requests.get("https://example.com/data?last_id=%i" % last_id, headers=headers) + resp.raise_for_status() + # yield the data from the resource + data = resp.json() + # you may process the data here + # example transformation? 
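        # illustrative sketch only (not part of the original template): a transformation here
        # could, for example, keep just the fields you need from each item - the "id" and "text"
        # field names below are hypothetical:
        # data["items"] = [{"id": item.get("id"), "text": item.get("text")} for item in data["items"]]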
+ # return resource to be loaded into `example_data` table + # explain that data["items"] contains a list of items + yield data["items"] + + # return all the resources to be loaded + return example_data + +# configure the pipeline +dlt.pipeline(destination=bigquery, dataset="twitter") +# explain that api_key will be automatically loaded from secrets.toml or environment variable below +load_info = twitter_data(0).run() +# pretty print the information on data that was loaded +print(load_info) diff --git a/experiments/pipeline/examples/project_structure/requirements.txt b/experiments/pipeline/examples/project_structure/requirements.txt new file mode 100644 index 0000000000..1ecee01bab --- /dev/null +++ b/experiments/pipeline/examples/project_structure/requirements.txt @@ -0,0 +1,2 @@ +python-dlt[bigquery]==0.1.0rc14 +requests \ No newline at end of file diff --git a/experiments/pipeline/examples/secrets_and_config.md b/experiments/pipeline/examples/secrets_and_config.md new file mode 100644 index 0000000000..f3ec694f36 --- /dev/null +++ b/experiments/pipeline/examples/secrets_and_config.md @@ -0,0 +1,155 @@ +## Example +How config values and secrets are handled should promote good behavior + +1. secret values should never be present in the pipeline code +2. config values can be provided, changed etc. when pipeline is deployed +3. still it must be easy and intuitive + +For the source extractor function below (reads selected tab from google sheets) we can pass config values in following ways: + +```python + +import dlt +from dlt.destinations import bigquery + + +@dlt.source +def google_sheets(spreadsheet_id, tab_names, credentials, only_strings=False): + sheets = build('sheets', 'v4', credentials=Services.from_json(credentials)) + tabs = [] + for tab_name in tab_names: + data = sheets.get(spreadsheet_id, tab_name).execute().values() + tabs.append(dlt.resource(data, name=tab_name)) + return tabs + +# WRONG: provide all values directly - wrong but possible. secret values should never be present in the code! +google_sheets("23029402349032049", ["tab1", "tab2"], credentials={"private_key": ""}).run(destination=bigquery) + +# OPTION A: provide config values directly and secrets via automatic injection mechanism (see later) +# `credentials` value will be provided by the `source` decorator +# `spreadsheet_id` and `tab_names` take default values from the arguments below but may be overwritten by the decorator via config providers (see later) +google_sheets("23029402349032049", ["tab1", "tab2"]).run(destination=bigquery) + + +# OPTION B: all values are injected so there are no defaults and config values must be present in the providers +google_sheets().run(destination=bigquery) + + +# OPTION C: we use `dlt.secrets` and `dlt.config` to explicitly take those values from providers in the way we control (not recommended but straightforward) +google_sheets(dlt.config["sheet_id"], dlt.config["tabs"], dlt.secrets["gcp_credentials"]).run(destination=bigquery) +``` + +## Injection mechanism +By the magic of @dlt.source decorator + +The signature of the function `google_sheets` is also defining the structure of the configuration and secrets. + +When `google_sheets` function is called the decorator takes every input parameter and uses its value as initial. +Then it looks into `providers` if the value is not overwritten there. +It does the same for all arguments that were not in the call but are specified in function signature. 
+Then it calls the original function with the updated input arguments, thus passing config and secrets to it.
+
+## Providers
+When config or secret values are needed, `dlt` looks for them in providers. In the case of `google_sheets()` it will always look for: `spreadsheet_id`, `tab_names`, `credentials` and `strings_only`.
+
+Providers form a hierarchy. At the top are environment variables, then the `secrets.toml` and `config.toml` files. Providers like google, aws, azure vaults can be inserted after the environment provider.
+For example, if `spreadsheet_id` is found in environment variables, dlt does not look into the other providers.
+
+The values passed directly in the code are the lowest in the provider hierarchy.
+
+## Namespaces
+Config and secret values can be grouped in namespaces. The easiest way to visualize this is via `toml` files.
+
+This is valid for OPTION A and OPTION B
+
+**secrets.toml**
+```toml
+client_email =
+private_key =
+project_id =
+```
+**config.toml**
+```toml
+spreadsheet_id="302940230490234903294"
+tab_names=["tab1", "tab2"]
+```
+
+**alternative secrets.toml**
+**secrets.toml**
+```toml
+[credentials]
+client_email =
+private_key =
+project_id =
+```
+
+where `credentials` is the name of the parameter from `google_sheets`. This parameter is a namespace for the keys it contains and namespaces are *optional*.
+
+For OPTION C the user uses their own custom keys to get credentials so:
+**secrets.toml**
+```toml
+[gcp_credentials]
+client_email =
+private_key =
+project_id =
+```
+**config.toml**
+```toml
+sheet_id="302940230490234903294"
+tabs=["tab1", "tab2"]
+```
+
+But what about the `bigquery` credentials? In the case above it will reuse the credentials from **secrets.toml** (in OPTION A and B) but what if we need different credentials?
+
+Dlt has a nice optional namespace structure to handle all conflicts. It becomes useful in advanced cases like the one above. The secrets and config files may look as follows (and they will work with OPTION A and B)
+
+**secrets.toml**
+```toml
+[source.credentials]
+client_email =
+private_key =
+project_id =
+
+[destination.credentials]
+client_email =
+private_key =
+project_id =
+
+```
+**config.toml**
+```toml
+[source]
+spreadsheet_id="302940230490234903294"
+tab_names=["tab1", "tab2"]
+```
+
+How do namespaces work in environment variables? They are prefixes for the key, so to get `spreadsheet_id` `dlt` will look for
+
+`SOURCE__SPREADSHEET_ID` first and `SPREADSHEET_ID` second
+
+## Interesting / Advanced stuff.
+
+The approach above makes configs and secrets explicit and autogenerates the required lookups. It lets me, for example, **generate deployments** and **code templates for pipeline scripts** automatically because I know what the config parameters are and I have total control over the user's code and final values via the decorator.
+
+There's more cool stuff
+
+Here's how a professional source function should look:
+
+```python
+
+
+@dlt.source
+def google_sheets(spreadsheet_id: str, tab_names: List[str], credentials: TCredentials, only_strings=False):
+    sheets = build('sheets', 'v4', credentials=Services.from_json(credentials))
+    tabs = []
+    for tab_name in tab_names:
+        data = sheets.get(spreadsheet_id, tab_name).execute().values()
+        tabs.append(dlt.resource(data, name=tab_name))
+    return tabs
+```
+
+Here I provide typing so I can type check the injected values so no junk data gets passed to the function.
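To make the injection mechanism described above concrete, here is a minimal sketch of how a decorator could resolve arguments from a provider lookup. It is an illustration under simplified assumptions - a plain dict named `_PROVIDERS` stands in for the environment/`secrets.toml`/`config.toml` hierarchy and `with_injected_config` is a hypothetical name - not dlt's actual implementation:

```python
import inspect
from functools import wraps

# stand-in for the provider hierarchy (environment variables, secrets.toml, config.toml, ...)
_PROVIDERS = {"spreadsheet_id": "302940230490234903294", "tab_names": ["tab1", "tab2"]}

def with_injected_config(f):
    sig = inspect.signature(f)

    @wraps(f)
    def _wrap(*args, **kwargs):
        # values passed in the call and the signature defaults act as initial values...
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        for name in sig.parameters:
            # ...and provider values, when present, take precedence over them
            if name in _PROVIDERS:
                bound.arguments[name] = _PROVIDERS[name]
        return f(*bound.args, **bound.kwargs)

    return _wrap

@with_injected_config
def google_sheets(spreadsheet_id, tab_names, credentials=None, only_strings=False):
    print(spreadsheet_id, tab_names, credentials, only_strings)

# `spreadsheet_id` and `tab_names` are not passed here - they are resolved from the providers
google_sheets(credentials={"private_key": "..."})
```

The real mechanism additionally distinguishes secret values, supports namespaces and walks several providers in the order described above.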
+ +> I also tell which argument is secret via `TCredentials` that let's me control for the case when user is putting secret values in `config.toml` or some other unsafe provider (and generate even better templates) + +We could go even deeper here (ie. configurations `spec` may be explicitly declared via python `dataclasses`, may be embedded in one another etc. -> it comes useful when writing something really complicated) \ No newline at end of file diff --git a/experiments/pipeline/examples/stream_resources.md b/experiments/pipeline/examples/stream_resources.md new file mode 100644 index 0000000000..b46a2762ac --- /dev/null +++ b/experiments/pipeline/examples/stream_resources.md @@ -0,0 +1,4 @@ +Advanced: +There are stream resources that contain many data types ie. RASA Tracker so a single resource may map to many tables: +1. hints can also be functions and lambdas to create dynamic hints based on data items yielded +2. I will re-introduce the `with_table` modifier function from v1 - it is less efficient but more intuitive for the user \ No newline at end of file diff --git a/experiments/pipeline/examples/working_with_schemas.md b/experiments/pipeline/examples/working_with_schemas.md new file mode 100644 index 0000000000..29d93cc6b1 --- /dev/null +++ b/experiments/pipeline/examples/working_with_schemas.md @@ -0,0 +1,115 @@ +## General approach to define schemas + +## Schema components + +### Schema content hash and version +Each schema file contains content based hash `version_hash` that is used to +1. detect manual changes to schema (ie. user edits content) +2. detect if the destination database schema is synchronized with the file schema + +Each time the schema is saved, the version hash is updated. + +Each schema contains also numeric version which increases automatically whenever schema is updated and saved. This version is mostly for informative purposes and currently the user can easily reset it by wiping out the pipeline working dir (until we restore the current schema from the destination) + +> Currently the destination schema sync procedure uses the numeric version. I'm changing it to hash based versioning. + +### Normalizer and naming convention +The data normalizer and the naming convention are part of the schema configuration. In principle the source can set own naming convention or json unpacking mechanism. Or user can overwrite those in `config.toml` + +#### Relational normalizer config +Yes those are part of the normalizer module and can be plugged in. +1. column propagation from parent -> child +2. nesting level +3. parent -> child table linking type +### Global hints, preferred data type hints, data type autodetectors + +## Working with schema files +`dlt` automates working with schema files by setting up schema import and export folders. Settings are available via config providers (ie. `config.toml`) or via `dlt.pipeline(import_schema_path, export_schema_path)` settings. Example: +```python +dlt.pipeline(import_schema_path="schemas/import", export_schema_path="schemas/export") +``` +will create following folder structure in project root folder +``` +schemas + |---import/ + |---export/ +``` + +Which will expose pipeline schemas to the user in `yml` format. + +1. When new pipeline is created and source function is extracted for the first time a new schema is added to pipeline. This schema is created out of global hints and resource hints present in the source extractor function. It **does not depend on the data - which happens in normalize stage**. +2. 
+2. Every such new schema will be saved to the `import` folder (if not already present there) and used as the initial version for all future pipeline runs.
+3. Once a schema is present in the `import` folder, **it is writable by the user only**.
+4. Any change to the schemas in that folder is detected and propagated to the pipeline automatically on the next run (in fact any call to the `Pipeline` object does that sync). It means that after a user update, the schema in the `import` folder resets all the automatic updates from the data.
+5. Otherwise **the schema evolves automatically in the normalize stage** and each update is saved in the `export` folder. The export folder is **writable by dlt only** and provides the actual view of the schema.
+6. The `export` and `import` folders may be the same. In that case the evolved schema is automatically "accepted" as the initial one.
+
+
+## Working with schema in code
+A `dlt` user can "check out" any pipeline schema for modification in the code.
+
+> I do not have any cool API to work with the tables, columns and other hints in the code - the schema is a typed dictionary and currently that is the only way.
+
+`dlt` will "commit" all the schema changes with any call to the `run`, `extract`, `normalize` or `load` methods.
+
+Examples:
+
+```python
+# extract some data to the "table" resource using the default schema
+p = dlt.pipeline(destination=redshift)
+p.extract([1,2,3,4], name="table")
+# get the live schema
+schema = p.default_schema
+# we want the list data to be text, not integer
+schema.tables["table"]["columns"]["value"] = schema_utils.new_column("value", "text")
+# `run` will apply the schema changes and run the normalizer and loader for the already extracted data
+p.run()
+```
+
+> The `normalize` stage creates standalone load packages, each containing the data and a schema with a particular version. Those packages are of course not impacted by the "live" schema changes.
+
+## Attaching schemas to sources
+The general approach when creating a new pipeline is to set up a few global schema settings and then let the table and column schemas be generated from the resource hints and the data itself.
+
+> I do not have any cool "schema builder" api yet to set the global settings.
+
+Example:
+
+```python
+
+schema: Schema = None
+
+def setup_schema(nesting_level, hash_names_convention=False):
+    global schema
+
+    # get the default normalizer config
+    normalizer_conf = dlt.schema.normalizer_config()
+    # set the hash names convention which produces short names without clashes but very ugly
+    if hash_names_convention:
+        normalizer_conf["names"] = dlt.common.normalizers.names.hash_names
+    # remove the date detector and add a type detector that forces all fields to strings
+    normalizer_conf["detections"].remove("iso_timestamp")
+    normalizer_conf["detections"].insert(0, "all_text")
+
+    # apply the normalizer conf
+    schema = Schema("createx", normalizer_conf)
+    # set the nesting level, yeah it's ugly
+    schema._normalizers_config["json"].setdefault("config", {})["max_nesting"] = nesting_level
+
+# apply the schema to the source
+@dlt.source(schema=schema)
+def createx():
+    ...
+
+```
+
+Two other behaviors are supported:
+1. a bare `dlt.source` will create an empty schema with the source name
+2. `dlt.source(name=...)` will first try to load `{name}_schema.yml` from the same folder where the source python file exists. If not found, a new empty schema will be created
+
+
+## Open issues
+
+1. Name clashes.
+2. Lack of lineage.
+3. 
Names, types and hints interpretation depend on destination diff --git a/experiments/pipeline/exceptions.py b/experiments/pipeline/exceptions.py deleted file mode 100644 index af1df29f53..0000000000 --- a/experiments/pipeline/exceptions.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any, Sequence -from dlt.common.exceptions import DltException -from dlt.common.telemetry import TRunMetrics -from experiments.pipeline.typing import TPipelineStep - - -class PipelineException(DltException): - pass - - -class MissingDependencyException(PipelineException): - def __init__(self, caller: str, dependencies: Sequence[str], appendix: str = "") -> None: - self.caller = caller - self.dependencies = dependencies - super().__init__(self._get_msg(appendix)) - - def _get_msg(self, appendix: str) -> str: - msg = f""" -You must install additional dependencies to run {self.caller}. If you use pip you may do the following: - -{self._to_pip_install()} -""" - if appendix: - msg = msg + "\n" + appendix - return msg - - def _to_pip_install(self) -> str: - return "\n".join([f"pip install {d}" for d in self.dependencies]) - - -class NoPipelineException(PipelineException): - def __init__(self) -> None: - super().__init__("Please create or restore pipeline before using this function") - - -class PipelineConfigMissing(PipelineException): - def __init__(self, config_elem: str, step: TPipelineStep, help: str = None) -> None: - self.config_elem = config_elem - self.step = step - msg = f"Configuration element {config_elem} was not provided and {step} step cannot be executed" - if help: - msg += f"\n{help}\n" - super().__init__(msg) - - -class PipelineConfiguredException(PipelineException): - def __init__(self, f_name: str) -> None: - super().__init__(f"{f_name} cannot be called on already configured or restored pipeline.") - - -class InvalidPipelineContextException(PipelineException): - def __init__(self) -> None: - super().__init__("There may be just one active pipeline in single python process. To activate current pipeline call `activate` method") - - -class CannotRestorePipelineException(PipelineException): - def __init__(self, reason: str) -> None: - super().__init__(reason) - - -class SqlClientNotAvailable(PipelineException): - def __init__(self, client_type: str) -> None: - super().__init__(f"SQL Client not available in {client_type}") - - -class InvalidIteratorException(PipelineException): - def __init__(self, iterator: Any) -> None: - super().__init__(f"Unsupported source iterator or iterable type: {type(iterator).__name__}") - - -class InvalidItemException(PipelineException): - def __init__(self, item: Any) -> None: - super().__init__(f"Source yielded unsupported item type: {type(item).__name}. 
Only dictionaries, sequences and deferred items allowed.") - - -class PipelineStepFailed(PipelineException): - def __init__(self, step: TPipelineStep, exception: BaseException, run_metrics: TRunMetrics) -> None: - self.stage = step - self.exception = exception - self.run_metrics = run_metrics - super().__init__(f"Pipeline execution failed at stage {step} with exception:\n\n{type(exception)}\n{exception}") diff --git a/experiments/pipeline/pipeline.py b/experiments/pipeline/pipeline.py deleted file mode 100644 index dac3425811..0000000000 --- a/experiments/pipeline/pipeline.py +++ /dev/null @@ -1,481 +0,0 @@ -import os -from collections import abc -import tempfile -from contextlib import contextmanager -from copy import deepcopy -from functools import wraps -from typing import Any, List, Iterable, Iterator, Mapping, NewType, Optional, Sequence, Type, TypedDict, Union, overload -from operator import itemgetter -from prometheus_client import REGISTRY - -from dlt.common import json, logger, signals -from dlt.common.sources import DLT_METADATA_FIELD, with_table_name -from dlt.common.typing import DictStrAny, StrAny, TFun, TSecretValue, TAny - -from dlt.common.runners import pool_runner as runner, TRunMetrics, initialize_runner -from dlt.common.schema.utils import normalize_schema_name -from dlt.common.storages.live_schema_storage import LiveSchemaStorage -from dlt.common.storages.normalize_storage import NormalizeStorage -from dlt.common.storages.schema_storage import SchemaStorage - -from dlt.common.configuration import make_configuration, RunConfiguration, NormalizeVolumeConfiguration, SchemaVolumeConfiguration, ProductionNormalizeVolumeConfiguration -from dlt.common.schema.schema import Schema -from dlt.common.file_storage import FileStorage -from dlt.common.utils import is_interactive, uniq_id - -from dlt.extract.extractor_storage import ExtractorStorageBase -from dlt.load.typing import TLoaderCapabilities -from dlt.normalize.configuration import configuration as normalize_configuration -from dlt.normalize import Normalize -from dlt.load.client_base import SqlClientBase, SqlJobClientBase -from dlt.load.configuration import LoaderClientDwhConfiguration, configuration as loader_configuration -from dlt.load import Load - -from experiments.pipeline.configuration import get_config -from experiments.pipeline.exceptions import PipelineConfigMissing, PipelineConfiguredException, MissingDependencyException, PipelineStepFailed -from experiments.pipeline.sources import SourceTables, TResolvableDataItem - - -TConnectionString = NewType("TConnectionString", str) -TSourceState = NewType("TSourceState", DictStrAny) - -TCredentials = Union[TConnectionString, StrAny] - -class TPipelineState(TypedDict): - pipeline_name: str - default_dataset: str - # is_transient: bool - default_schema_name: Optional[str] - # pipeline_secret: TSecretValue - destination_name: Optional[str] - # schema_sync_path: Optional[str] - - -# class TPipelineState() -# sources: Dict[str, TSourceState] - - -class PipelineConfiguration(RunConfiguration): - WORKING_DIR: Optional[str] = None - PIPELINE_SECRET: Optional[TSecretValue] = None - drop_existing_data: bool = False - - @classmethod - def check_integrity(cls) -> None: - if cls.PIPELINE_SECRET: - cls.PIPELINE_SECRET = uniq_id() - - -class Pipeline: - - ACTIVE_INSTANCE: "Pipeline" = None - STATE_FILE = "state.json" - - def __new__(cls: Type["Pipeline"]) -> "Pipeline": - cls.ACTIVE_INSTANCE = super().__new__(cls) - return cls.ACTIVE_INSTANCE - - def __init__(self): - # pipeline is not 
configured yet - # self.is_configured = False - # self.pipeline_name: str = None - # self.pipeline_secret: str = None - # self.default_schema_name: str = None - # self.default_dataset_name: str = None - # self.working_dir: str = None - # self.is_transient: bool = None - self.CONFIG: Type[PipelineConfiguration] = None - self.root_folder: str = None - - self._initial_values: DictStrAny = {} - self._state: TPipelineState = {} - self._pipeline_storage: FileStorage = None - self._extractor_storage: ExtractorStorageBase = None - self._schema_storage: LiveSchemaStorage = None - - def only_not_configured(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - if self.CONFIG: - raise PipelineConfiguredException(f.__name__) - return f(self, *args, **kwargs) - - return _wrap - - def maybe_default_config(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - if not self.CONFIG: - self.configure() - return f(self, *args, **kwargs) - - return _wrap - - def with_state_sync(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - with self._managed_state(): - return f(self, *args, **kwargs) - - return _wrap - - def with_schemas_sync(f: TFun) -> TFun: - - @wraps(f) - def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: - for name in self._schema_storage.live_schemas: - # refresh live schemas in storage or import schema path - self._schema_storage.commit_live_schema(name) - return f(self, *args, **kwargs) - - return _wrap - - - @overload - def configure(self, - pipeline_name: str = None, - working_dir: str = None, - pipeline_secret: TSecretValue = None, - drop_existing_data: bool = False, - import_schema_path: str = None, - export_schema_path: str = None, - destination_name: str = None, - log_level: str = "INFO" - ) -> None: - ... 
- - - @only_not_configured - @with_state_sync - def configure(self, **kwargs: Any) -> None: - # keep the locals to be able to initialize configs at any time - self._initial_values.update(**kwargs) - # resolve pipeline configuration - self.CONFIG = self._get_config(PipelineConfiguration) - - # use system temp folder if not specified - if not self.CONFIG.WORKING_DIR: - self.CONFIG.WORKING_DIR = tempfile.gettempdir() - self.root_folder = os.path.join(self.CONFIG.WORKING_DIR, self.CONFIG.PIPELINE_NAME) - self._set_common_initial_values() - - # create pipeline working dir - self._pipeline_storage = FileStorage(self.root_folder, makedirs=False) - - # remove existing pipeline if requested - if self._pipeline_storage.has_folder(".") and self.CONFIG.drop_existing_data: - self._pipeline_storage.delete_folder(".") - - # restore pipeline if folder exists and contains state - if self._pipeline_storage.has_file(Pipeline.STATE_FILE): - self._restore_pipeline() - else: - self._create_pipeline() - - # create schema storage - self._schema_storage = LiveSchemaStorage(self._get_config(SchemaVolumeConfiguration), makedirs=True) - # create extractor storage - self._extractor_storage = ExtractorStorageBase( - "1.0.0", - True, - FileStorage(os.path.join(self.root_folder, "extract"), makedirs=True), - self._ensure_normalize_storage() - ) - - initialize_runner(self.CONFIG) - - - def _get_config(self, spec: Type[TAny], accept_partial: bool = False) -> Type[TAny]: - print(self._initial_values) - return make_configuration(spec, spec, initial_values=self._initial_values, accept_partial=accept_partial) - - - @overload - def extract( - self, - data: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], - table_name = None, - write_disposition = None, - parent = None, - columns = None, - max_parallel_data_items: int = 20, - schema: Schema = None - ) -> None: - ... - - @overload - def extract( - self, - data: SourceTables, - max_parallel_iterators: int = 1, - max_parallel_data_items: int = 20, - schema: Schema = None - ) -> None: - ... 
- - @maybe_default_config - @with_schemas_sync - @with_state_sync - def extract( - self, - data: Union[SourceTables, Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], - table_name = None, - write_disposition = None, - parent = None, - columns = None, - max_parallel_iterators: int = 1, - max_parallel_data_items: int = 20, - schema: Schema = None - ) -> None: - self._schema_storage.save_schema(schema) - self._state["default_schema_name"] = schema.name - # TODO: apply hints to table - - # check if iterator or iterable is supported - # if isinstance(items, str) or isinstance(items, dict) or not - # TODO: check if schema exists - with self._managed_state(): - default_table_name = table_name or self.CONFIG.PIPELINE_NAME - # TODO: this is not very effective - we consume iterator right away, better implementation needed where we stream iterator to files directly - all_items: List[DictStrAny] = [] - for item in data: - # dispatch items by type - if callable(item): - item = item() - if isinstance(item, dict): - all_items.append(item) - elif isinstance(item, abc.Sequence): - all_items.extend(item) - # react to CTRL-C and shutdowns from controllers - signals.raise_if_signalled() - - try: - self._extract_iterator(default_table_name, all_items) - except Exception: - raise PipelineStepFailed("extract", self.last_run_exception, runner.LAST_RUN_METRICS) - - # @maybe_default_config - # @with_schemas_sync - # @with_state_sync - # def extract_many() -> None: - # pass - - @with_schemas_sync - def normalize(self, dry_run: bool = False, workers: int = 1, max_events_in_chunk: int = 100000) -> None: - if is_interactive() and workers > 1: - raise NotImplementedError("Do not use normalize workers in interactive mode ie. in notebook") - # set parameters to be passed to config - normalize = self._configure_normalize({ - "WORKERS": workers, - "MAX_EVENTS_IN_CHUNK": max_events_in_chunk, - "POOL_TYPE": "thread" if workers == 1 else "process" - }) - try: - ec = runner.run_pool(normalize.CONFIG, normalize) - # in any other case we raise if runner exited with status failed - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) - return ec - except Exception as r_ex: - # if EXIT_ON_EXCEPTION flag is set, exception will bubble up directly - raise PipelineStepFailed("normalize", self.last_run_exception, runner.LAST_RUN_METRICS) from r_ex - finally: - signals.raise_if_signalled() - - @with_schemas_sync - @with_state_sync - def load( - self, - destination_name: str = None, - default_dataset: str = None, - credentials: TCredentials = None, - raise_on_failed_jobs = False, - raise_on_incompatible_schema = False, - always_drop_dataset = False, - dry_run: bool = False, - max_parallel_loads: int = 20, - normalize_workers: int = 1 - ) -> None: - self._resolve_load_client_config() - # check if anything to normalize - if len(self._extractor_storage.normalize_storage.list_files_to_normalize_sorted()) > 0: - self.normalize(dry_run=dry_run, workers=normalize_workers) - # then load - print(locals()) - load = self._configure_load(locals(), credentials) - runner.run_pool(load.CONFIG, load) - if runner.LAST_RUN_METRICS.has_failed: - raise PipelineStepFailed("load", self.last_run_exception, runner.LAST_RUN_METRICS) - - def activate(self) -> None: - # make this instance the active one - pass - - @property - def schemas(self) -> Mapping[str, Schema]: - return self._schema_storage - - @property - def default_schema(self) -> Schema: - return 
self.schemas[self._state.get("default_schema_name")] - - @property - def last_run_exception(self) -> BaseException: - return runner.LAST_RUN_EXCEPTION - - def _create_pipeline(self) -> None: - self._pipeline_storage.create_folder(".", exists_ok=True) - - def _restore_pipeline(self) -> None: - self._restore_state() - - def _ensure_normalize_storage(self) -> NormalizeStorage: - return NormalizeStorage(True, self._get_config(NormalizeVolumeConfiguration)) - - def _configure_normalize(self, initial_values: DictStrAny) -> Normalize: - destination_name = self._ensure_destination_name() - format = self._get_loader_capabilities(destination_name)["preferred_loader_file_format"] - # create normalize config - initial_values.update({ - "LOADER_FILE_FORMAT": format, - "ADD_EVENT_JSON": False - }) - # apply schema storage config - # initial_values.update(self._schema_storage.C.as_dict()) - # apply common initial settings - initial_values.update(self._initial_values) - C = normalize_configuration(initial_values=initial_values) - print(C.as_dict()) - # shares schema storage with the pipeline so we do not need to install - return Normalize(C, schema_storage=self._schema_storage) - - def _configure_load(self, loader_initial: DictStrAny, credentials: TCredentials = None) -> Load: - # get destination or raise - destination_name = self._ensure_destination_name() - # import load client for given destination or raise - self._get_loader_capabilities(destination_name) - # get default dataset or raise - default_dataset = self._ensure_default_dataset() - - loader_initial.update({ - "DELETE_COMPLETED_JOBS": True, - "CLIENT_TYPE": destination_name - }) - loader_initial.update(self._initial_values) - - loader_client_initial = { - "DEFAULT_DATASET": default_dataset, - "DEFAULT_SCHEMA_NAME": self._state.get("default_schema_name") - } - if credentials: - loader_client_initial.update(credentials) - - C = loader_configuration(initial_values=loader_initial) - return Load(C, REGISTRY, client_initial_values=loader_client_initial, is_storage_owner=False) - - def _set_common_initial_values(self) -> None: - self._initial_values.update({ - "IS_SINGLE_RUN": True, - "EXIT_ON_EXCEPTION": True, - "LOAD_VOLUME_PATH": os.path.join(self.root_folder, "load"), - "NORMALIZE_VOLUME_PATH": os.path.join(self.root_folder, "normalize"), - "SCHEMA_VOLUME_PATH": os.path.join(self.root_folder, "schemas") - }) - - def _get_loader_capabilities(self, destination_name: str) -> TLoaderCapabilities: - try: - return Load.loader_capabilities(destination_name) - except ImportError: - raise MissingDependencyException( - f"{destination_name} destination", - [f"python-dlt[{destination_name}]"], - "Dependencies for specific destinations are available as extras of python-dlt" - ) - - def _resolve_load_client_config(self) -> Type[LoaderClientDwhConfiguration]: - return get_config( - LoaderClientDwhConfiguration, - initial_values={ - "client_type": self._initial_values.get("destination_name"), - "default_dataset": self._initial_values.get("default_dataset") - }, - accept_partial=True - ) - - def _ensure_destination_name(self) -> str: - d_n = self._resolve_load_client_config().CLIENT_TYPE - if not d_n: - raise PipelineConfigMissing( - "destination_name", - "normalize", - "Please provide `destination_name` argument to `config` or `load` method or via pipeline config file or environment var." 
- ) - return d_n - - def _ensure_default_dataset(self) -> str: - d_n = self._resolve_load_client_config().DEFAULT_DATASET - if not d_n: - d_n = normalize_schema_name(self.CONFIG.PIPELINE_NAME) - return d_n - - def _extract_iterator(self, default_table_name: str, items: Sequence[DictStrAny]) -> None: - try: - for idx, i in enumerate(items): - if not isinstance(i, dict): - # TODO: convert non dict types into dict - items[idx] = i = {"v": i} - if DLT_METADATA_FIELD not in i or i.get(DLT_METADATA_FIELD, None) is None: - # set default table name - with_table_name(i, default_table_name) - - load_id = uniq_id() - self._extractor_storage.save_json(f"{load_id}.json", items) - self._extractor_storage.commit_events( - self.default_schema.name, - self._extractor_storage.storage._make_path(f"{load_id}.json"), - default_table_name, - len(items), - load_id - ) - - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=False, pending_items=0) - except Exception as ex: - logger.exception("extracting iterator failed") - runner.LAST_RUN_METRICS = TRunMetrics(was_idle=False, has_failed=True, pending_items=0) - runner.LAST_RUN_EXCEPTION = ex - raise - - @contextmanager - def _managed_state(self) -> Iterator[None]: - backup_state = deepcopy(self._state) - try: - yield - except Exception: - # restore old state - self._state.clear() - self._state.update(backup_state) - raise - else: - # persist old state - # TODO: compare backup and new state, save only if different - self._pipeline_storage.save(Pipeline.STATE_FILE, json.dumps(self._state)) - - def _restore_state(self) -> None: - self._state.clear() - restored_state: DictStrAny = json.loads(self._pipeline_storage.load(Pipeline.STATE_FILE)) - self._state.update(restored_state) - - @property - def is_active(self) -> bool: - return id(self) == id(Pipeline.ACTIVE_INSTANCE) - - @property - def has_pending_loads(self) -> bool: - # TODO: check if has pending normalizer and loader data - pass - -# active instance always present -Pipeline.ACTIVE_INSTANCE = Pipeline() diff --git a/experiments/pipeline/sources.py b/experiments/pipeline/sources.py deleted file mode 100644 index b2b5d1ca91..0000000000 --- a/experiments/pipeline/sources.py +++ /dev/null @@ -1,86 +0,0 @@ -from collections import abc -from typing import Iterable, Iterator, List, Union, Awaitable, Callable, Sequence, TypeVar, cast - -from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TTableSchema -from dlt.common.typing import TDataItem - - -TDirectDataItem = Union[TDataItem, Sequence[TDataItem]] -TDeferredDataItem = Callable[[], TDirectDataItem] -TAwaitableDataItem = Awaitable[TDirectDataItem] -TResolvableDataItem = Union[TDirectDataItem, TDeferredDataItem, TAwaitableDataItem] - -# TBoundItem = TypeVar("TBoundItem", bound=TDataItem) -# TDeferreBoundItem = Callable[[], TBoundItem] - - -class TableMetadataMixin: - def __init__(self, table_schema: TTableSchema, schema: Schema = None, selected_tables: List[str] = None): - self._table_schema = table_schema - self.schema = schema - self._table_name = table_schema["name"] - self.__name__ = self._table_name - self.selected_tables = selected_tables - - @property - def table_schema(self): - # TODO: returns unified table schema by merging _schema and _table with table taking precedence - return self._table_schema - - -_i_info: TableMetadataMixin = None - - -def extractor_resolver(i: Union[Iterator[TResolvableDataItem], Iterable[TResolvableDataItem]], selected_tables: List[str] = None) -> Iterator[TDataItem]: - - if not isinstance(i, 
abc.Iterator): - i = iter(i) - - # for item in i: - - - -class TableIterable(abc.Iterable, TableMetadataMixin): - def __init__(self, i, table, schema = None, selected_tables: List[str] = None): - self._data = i - super().__init__(table, schema, selected_tables) - - def __iter__(self): - # TODO: this should resolve the _data like we do in the extract method: all awaitables and deferred items are resolved - # possibly in parallel. - resolved_data = extractor_resolver(self._data) - return TableIterator(resolved_data, self._table_schema, self.schema, self.selected_tables) - - - -class TableIterator(abc.Iterator, TableMetadataMixin): - def __init__(self, i, table, schema = None, selected_tables: List[str] = None): - self.i = i - super().__init__(table, schema, selected_tables) - - # def __next__(self): - # # export metadata to global variable so it can be read by extractor - # # TODO: remove this hack if possible - # global _i_info - # _i_info = cast(self, TableMetadataMixin) - - # if callable(self._table_name): - # else: - # if no table filter selected - # return next(self.i) - # while True: - # ni = next(self.i) - # if callable(self._table_name): - # # table name is a lambda, so resolve table name - # t_n = self._table_name(ni) - # return - - # def __iter__(self): - # return self - - -class SourceTables(List[TableIterable]): - pass - - diff --git a/experiments/pipeline/typing.py b/experiments/pipeline/typing.py deleted file mode 100644 index d38bb9cba4..0000000000 --- a/experiments/pipeline/typing.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing import Literal - - -TPipelineStep = Literal["extract", "normalize", "load"] \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 33f3f0824c..f3eebe74e5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,6 +10,7 @@ check_untyped_defs=true warn_return_any=true namespace_packages=true warn_unused_ignores=true +enable_incomplete_features=true ;disallow_any_generics=false diff --git a/poetry.lock b/poetry.lock index db4b1ed3ab..08de21b649 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3,7 +3,7 @@ name = "agate" version = "1.6.3" description = "A data analysis library that is optimized for humans instead of machines." category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -84,7 +84,7 @@ name = "babel" version = "2.10.3" description = "Internationalization utilities" category = "main" -optional = true +optional = false python-versions = ">=3.6" [package.dependencies] @@ -109,38 +109,6 @@ test = ["coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "stestr toml = ["toml"] yaml = ["pyyaml"] -[[package]] -name = "boto3" -version = "1.24.76" -description = "The AWS SDK for Python" -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -botocore = ">=1.27.76,<1.28.0" -jmespath = ">=0.7.1,<2.0.0" -s3transfer = ">=0.6.0,<0.7.0" - -[package.extras] -crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] - -[[package]] -name = "botocore" -version = "1.27.76" -description = "Low-level, data-driven core of boto 3." -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -jmespath = ">=0.7.1,<2.0.0" -python-dateutil = ">=2.1,<3.0.0" -urllib3 = ">=1.25.4,<1.27" - -[package.extras] -crt = ["awscrt (==0.14.0)"] - [[package]] name = "cachetools" version = "5.2.0" @@ -162,7 +130,7 @@ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." 
category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -184,7 +152,7 @@ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" category = "main" -optional = true +optional = false python-versions = ">=3.7" [package.dependencies] @@ -198,29 +166,13 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -[[package]] -name = "dbt-bigquery" -version = "1.0.0" -description = "The BigQuery adapter plugin for dbt" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -dbt-core = ">=1.0.0,<1.1.0" -google-api-core = ">=1.16.0,<3" -google-cloud-bigquery = ">=1.25.0,<3" -google-cloud-core = ">=1.3.0,<3" -googleapis-common-protos = ">=1.6.0,<2" -protobuf = ">=3.13.0,<4" - [[package]] name = "dbt-core" -version = "1.0.6" +version = "1.1.2" description = "With dbt, data analysts and engineers can build analytics the way engineers build applications." category = "main" -optional = true -python-versions = ">=3.7" +optional = false +python-versions = ">=3.7.2" [package.dependencies] agate = ">=1.6,<1.6.4" @@ -228,7 +180,7 @@ cffi = ">=1.9,<2.0.0" click = ">=7.0,<9" colorama = ">=0.3.9,<0.4.5" dbt-extractor = ">=0.4.1,<0.5.0" -hologram = "0.0.14" +hologram = ">=0.0.14,<=0.0.15" idna = ">=2.5,<4" isodate = ">=0.6,<0.7" Jinja2 = "2.11.3" @@ -236,11 +188,11 @@ logbook = ">=1.5,<1.6" MarkupSafe = ">=0.23,<2.1" mashumaro = "2.9" minimal-snowplow-tracker = "0.0.2" -networkx = ">=2.3,<3" +networkx = ">=2.3,<2.8.4" packaging = ">=20.9,<22.0" requests = "<3.0.0" sqlparse = ">=0.2.3,<0.5" -typing-extensions = ">=3.7.4,<3.11" +typing-extensions = ">=3.7.4" werkzeug = ">=1,<3" [[package]] @@ -248,34 +200,9 @@ name = "dbt-extractor" version = "0.4.1" description = "A tool to analyze and extract information from Jinja used in dbt projects." category = "main" -optional = true +optional = false python-versions = ">=3.6.1" -[[package]] -name = "dbt-postgres" -version = "1.0.6" -description = "The postgres adpter plugin for dbt (data build tool)" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -dbt-core = "1.0.6" -psycopg2-binary = ">=2.8,<3.0" - -[[package]] -name = "dbt-redshift" -version = "1.0.1" -description = "The Redshift adapter plugin for dbt" -category = "main" -optional = true -python-versions = ">=3.7" - -[package.dependencies] -boto3 = ">=1.4.4,<2.0.0" -dbt-core = ">=1.0.0,<1.1.0" -dbt-postgres = ">=1.0.0,<1.1.0" - [[package]] name = "decopatch" version = "1.4.10" @@ -304,18 +231,6 @@ typing-extensions = ">=3.7.4.1" all = ["pytz (>=2019.1)"] dates = ["pytz (>=2019.1)"] -[[package]] -name = "fire" -version = "0.4.0" -description = "A library for automatically generating command line interfaces." 
-category = "main" -optional = false -python-versions = "*" - -[package.dependencies] -six = "*" -termcolor = "*" - [[package]] name = "flake8" version = "5.0.4" @@ -403,7 +318,7 @@ name = "future" version = "0.18.2" description = "Clean single-source support for Python 3 and 2" category = "main" -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] @@ -614,14 +529,14 @@ dev = ["jinja2 (>=3.0.0,<3.1.0)", "towncrier (>=21,<22)", "sphinx-rtd-theme (>=0 [[package]] name = "hologram" -version = "0.0.14" +version = "0.0.15" description = "JSON schema generation from dataclasses" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] -jsonschema = ">=3.0,<3.2" +jsonschema = ">=3.0,<4.0" python-dateutil = ">=2.8,<2.9" [[package]] @@ -636,7 +551,7 @@ python-versions = ">=3.5" name = "importlib-metadata" version = "4.12.0" description = "Read metadata from Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -661,7 +576,7 @@ name = "isodate" version = "0.6.1" description = "An ISO 8601 date/time/duration parser and formatter" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -672,7 +587,7 @@ name = "jinja2" version = "2.11.3" description = "A very fast and expressive template engine." category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.dependencies] @@ -681,14 +596,6 @@ MarkupSafe = ">=0.23" [package.extras] i18n = ["Babel (>=0.8)"] -[[package]] -name = "jmespath" -version = "1.0.1" -description = "JSON Matching Expressions" -category = "main" -optional = true -python-versions = ">=3.7" - [[package]] name = "json-logging" version = "1.4.1rc0" @@ -707,27 +614,27 @@ python-versions = ">=3.6" [[package]] name = "jsonschema" -version = "3.1.1" +version = "3.2.0" description = "An implementation of JSON Schema validation for Python" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] attrs = ">=17.4.0" -importlib-metadata = "*" pyrsistent = ">=0.14.0" six = ">=1.11.0" [package.extras] format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors"] +format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"] [[package]] name = "leather" version = "0.3.4" description = "Python charting for 80% of humans." category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -738,7 +645,7 @@ name = "logbook" version = "1.5.3" description = "A logging replacement for Python" category = "main" -optional = true +optional = false python-versions = "*" [package.extras] @@ -765,7 +672,7 @@ name = "markupsafe" version = "2.0.1" description = "Safely add untrusted strings to HTML/XML markup." category = "main" -optional = true +optional = false python-versions = ">=3.6" [[package]] @@ -773,7 +680,7 @@ name = "mashumaro" version = "2.9" description = "Fast serialization framework on top of dataclasses" category = "main" -optional = true +optional = false python-versions = ">=3.6" [package.dependencies] @@ -794,7 +701,7 @@ name = "minimal-snowplow-tracker" version = "0.0.2" description = "A minimal snowplow event tracker for Python. 
Add analytics to your Python and Django apps, webapps and games" category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] @@ -806,16 +713,16 @@ name = "msgpack" version = "1.0.4" description = "MessagePack serializer" category = "main" -optional = true +optional = false python-versions = "*" [[package]] name = "mypy" -version = "0.971" +version = "0.982" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -823,9 +730,9 @@ tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = ">=3.10" [package.extras] -reports = ["lxml"] -python2 = ["typed-ast (>=1.4.0,<2)"] dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" @@ -849,16 +756,16 @@ icu = ["PyICU (>=1.0.0)"] [[package]] name = "networkx" -version = "2.8.6" +version = "2.8.3" description = "Python package for creating and manipulating graphs and networks" category = "main" -optional = true +optional = false python-versions = ">=3.8" [package.extras] default = ["numpy (>=1.19)", "scipy (>=1.8)", "matplotlib (>=3.4)", "pandas (>=1.3)"] -developer = ["pre-commit (>=2.20)", "mypy (>=0.961)"] -doc = ["sphinx (>=5)", "pydata-sphinx-theme (>=0.9)", "sphinx-gallery (>=0.10)", "numpydoc (>=1.4)", "pillow (>=9.1)", "nb2plots (>=0.6)", "texext (>=0.6.6)"] +developer = ["pre-commit (>=2.19)", "mypy (>=0.960)"] +doc = ["sphinx (>=4.5)", "pydata-sphinx-theme (>=0.8.1)", "sphinx-gallery (>=0.10)", "numpydoc (>=1.3)", "pillow (>=9.1)", "nb2plots (>=0.6)", "texext (>=0.6.6)"] extra = ["lxml (>=4.6)", "pygraphviz (>=1.9)", "pydot (>=1.4.2)", "sympy (>=1.10)"] test = ["pytest (>=7.1)", "pytest-cov (>=3.0)", "codecov (>=2.1)"] @@ -886,12 +793,23 @@ name = "parsedatetime" version = "2.4" description = "Parse human-readable date/time text." category = "main" -optional = true +optional = false python-versions = "*" [package.dependencies] future = "*" +[[package]] +name = "pathvalidate" +version = "2.5.2" +description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc." 
+category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +test = ["allpairspy", "click", "faker", "pytest (>=6.0.1)", "pytest-discord (>=0.0.6)", "pytest-md-report (>=0.0.12)"] + [[package]] name = "pbr" version = "5.10.0" @@ -1028,7 +946,7 @@ name = "pycparser" version = "2.21" description = "C parser in Python" category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] @@ -1052,10 +970,10 @@ diagrams = ["railroad-diagrams", "jinja2"] [[package]] name = "pyrsistent" -version = "0.18.1" +version = "0.19.1" description = "Persistent/Functional/Immutable data structures" category = "main" -optional = true +optional = false python-versions = ">=3.7" [[package]] @@ -1144,7 +1062,7 @@ name = "python-slugify" version = "6.1.2" description = "A Python slugify application that also handles Unicode" category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" [package.dependencies] @@ -1158,15 +1076,15 @@ name = "pytimeparse" version = "1.1.8" description = "Time expression parser" category = "main" -optional = true +optional = false python-versions = "*" [[package]] name = "pytz" -version = "2022.2.1" +version = "2022.5" description = "World timezone definitions, modern and historical" category = "main" -optional = true +optional = false python-versions = "*" [[package]] @@ -1185,17 +1103,6 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" -[[package]] -name = "randomname" -version = "0.1.5" -description = "Generate random adj-noun names like docker and github." -category = "main" -optional = false -python-versions = "*" - -[package.dependencies] -fire = "*" - [[package]] name = "requests" version = "2.28.1" @@ -1225,20 +1132,6 @@ python-versions = ">=3.6,<4" [package.dependencies] pyasn1 = ">=0.1.3" -[[package]] -name = "s3transfer" -version = "0.6.0" -description = "An Amazon S3 Transfer Manager" -category = "main" -optional = true -python-versions = ">= 3.7" - -[package.dependencies] -botocore = ">=1.12.36,<2.0a.0" - -[package.extras] -crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] - [[package]] name = "semver" version = "2.13.0" @@ -1305,10 +1198,10 @@ python-versions = ">=3.6" [[package]] name = "sqlparse" -version = "0.4.2" +version = "0.4.3" description = "A non-validating SQL parser." 
category = "main" -optional = true +optional = false python-versions = ">=3.5" [[package]] @@ -1322,23 +1215,12 @@ python-versions = ">=3.8" [package.dependencies] pbr = ">=2.0.0,<2.1.0 || >2.1.0" -[[package]] -name = "termcolor" -version = "2.0.1" -description = "ANSI color formatting for output in terminal" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -tests = ["pytest-cov", "pytest"] - [[package]] name = "text-unidecode" version = "1.3" description = "The most basic Text::Unidecode port" category = "main" -optional = true +optional = false python-versions = "*" [[package]] @@ -1426,11 +1308,11 @@ python-versions = "*" [[package]] name = "typing-extensions" -version = "3.10.0.2" -description = "Backported and Experimental Type Hints for Python 3.5+" +version = "4.4.0" +description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false -python-versions = "*" +python-versions = ">=3.7" [[package]] name = "tzdata" @@ -1458,7 +1340,7 @@ name = "werkzeug" version = "2.1.2" description = "The comprehensive WSGI web application library." category = "main" -optional = true +optional = false python-versions = ">=3.7" [package.extras] @@ -1468,7 +1350,7 @@ watchdog = ["watchdog"] name = "zipp" version = "3.8.1" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -1478,7 +1360,6 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [extras] bigquery = ["grpcio", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pyarrow"] -dbt = ["dbt-core", "GitPython", "dbt-redshift", "dbt-bigquery"] gcp = ["grpcio", "google-cloud-bigquery", "google-cloud-bigquery-storage", "pyarrow"] postgres = ["psycopg2-binary", "psycopg2cffi"] redshift = ["psycopg2-binary", "psycopg2cffi"] @@ -1486,7 +1367,7 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] [metadata] lock-version = "1.1" python-versions = "^3.8,<3.11" -content-hash = "d04bbf2afa3c4f46ef5725465da8baad95da271da965408c208c2557b6af198a" +content-hash = "f3ce0afb16174d4f0b4e297adba698c13078f3f3cfee6526b776b8096720c33b" [metadata.files] agate = [ @@ -1509,8 +1390,6 @@ bandit = [ {file = "bandit-1.7.4-py3-none-any.whl", hash = "sha256:412d3f259dab4077d0e7f0c11f50f650cc7d10db905d98f6520a95a18049658a"}, {file = "bandit-1.7.4.tar.gz", hash = "sha256:2d63a8c573417bae338962d4b9b06fbc6080f74ecd955a092849e1e65c717bd2"}, ] -boto3 = [] -botocore = [] cachetools = [ {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"}, {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"}, @@ -1591,14 +1470,7 @@ colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] -dbt-bigquery = [ - {file = "dbt-bigquery-1.0.0.tar.gz", hash = "sha256:e22442f00fcec155dcbfe8be351a11c35913fb6edd11bd5e52fafc3218abd12e"}, - {file = "dbt_bigquery-1.0.0-py3-none-any.whl", hash = "sha256:48778c89a37dd866ffd3718bf6b78e1139b7fb4cc0377f2feaa95e10dc3ce9c2"}, -] -dbt-core = [ - {file = "dbt-core-1.0.6.tar.gz", hash = "sha256:5155bc4e81aba9df1a9a183205c0a240a3ec08d4fb9377df4f0d4d4b96268be1"}, - {file = "dbt_core-1.0.6-py3-none-any.whl", hash = 
"sha256:20e8e4fdd9ad08a25b3fb7020ffbdfd3b9aa6339a63a3d125f3f6d3edc2605f2"}, -] +dbt-core = [] dbt-extractor = [ {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_7_x86_64.whl", hash = "sha256:4dc715bd740e418d8dc1dd418fea508e79208a24cf5ab110b0092a3cbe96bf71"}, {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bc9e0050e3a2f4ea9fe58e8794bc808e6709a0c688ed710fc7c5b6ef3e5623ec"}, @@ -1617,20 +1489,11 @@ dbt-extractor = [ {file = "dbt_extractor-0.4.1-cp36-abi3-win_amd64.whl", hash = "sha256:35265a0ae0a250623b0c2e3308b2738dc8212e40e0aa88407849e9ea090bb312"}, {file = "dbt_extractor-0.4.1.tar.gz", hash = "sha256:75b1c665699ec0f1ffce1ba3d776f7dfce802156f22e70a7b9c8f0b4d7e80f42"}, ] -dbt-postgres = [ - {file = "dbt-postgres-1.0.6.tar.gz", hash = "sha256:f560ab7178e19990b9d1e5d4787a9f5c7104708a0bf09b8693548723b1d9dfc2"}, - {file = "dbt_postgres-1.0.6-py3-none-any.whl", hash = "sha256:3cf9d76d87768f7e398c86ade6c5be7fa1a3984384beb3a63a7c0b2008e6aec8"}, -] -dbt-redshift = [ - {file = "dbt-redshift-1.0.1.tar.gz", hash = "sha256:1e45d2948313a588d54d7b59354e7850a969cf2aafb4d3581f3a733cb0170e68"}, - {file = "dbt_redshift-1.0.1-py3-none-any.whl", hash = "sha256:1e5219d67c6c7a52235c46c7ca559b118ac7a5e1e62e6b3138eaa1cb67597751"}, -] decopatch = [ {file = "decopatch-1.4.10-py2.py3-none-any.whl", hash = "sha256:e151f7f93de2b1b3fd3f3272dcc7cefd1a69f68ec1c2d8e288ecd9deb36dc5f7"}, {file = "decopatch-1.4.10.tar.gz", hash = "sha256:957f49c93f4150182c23f8fb51d13bb3213e0f17a79e09c8cca7057598b55720"}, ] domdf-python-tools = [] -fire = [] flake8 = [] flake8-bugbear = [] flake8-builtins = [ @@ -1725,10 +1588,7 @@ grpcio-status = [ {file = "grpcio_status-1.43.0-py3-none-any.whl", hash = "sha256:9036b24f5769adafdc3e91d9434c20e9ede0b30f50cc6bff105c0f414bb9e0e0"}, ] hexbytes = [] -hologram = [ - {file = "hologram-0.0.14-py3-none-any.whl", hash = "sha256:2911b59115bebd0504eb089532e494fa22ac704989afe41371c5361780433bfe"}, - {file = "hologram-0.0.14.tar.gz", hash = "sha256:fd67bd069e4681e1d2a447df976c65060d7a90fee7f6b84d133fd9958db074ec"}, -] +hologram = [] idna = [] importlib-metadata = [ {file = "importlib_metadata-4.12.0-py3-none-any.whl", hash = "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23"}, @@ -1746,10 +1606,6 @@ jinja2 = [ {file = "Jinja2-2.11.3-py2.py3-none-any.whl", hash = "sha256:03e47ad063331dd6a3f04a43eddca8a966a26ba0c5b7207a9a9e4e08f1b29419"}, {file = "Jinja2-2.11.3.tar.gz", hash = "sha256:a6d58433de0ae800347cab1fa3043cebbabe8baa9d29e668f1c768cb87a333c6"}, ] -jmespath = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] json-logging = [ {file = "json-logging-1.4.1rc0.tar.gz", hash = "sha256:381e00495bbd619d09c8c3d1fdd72c843f7045797ab63b42cfec5f7961e5b3f6"}, {file = "json_logging-1.4.1rc0-py2.py3-none-any.whl", hash = "sha256:2b787c28f31fb4d8aabac16ac3816326031d92dd054bdabc9bbe68eb10864f77"}, @@ -1758,10 +1614,7 @@ jsonlines = [ {file = "jsonlines-2.0.0-py3-none-any.whl", hash = "sha256:bfb043d4e25fd894dca67b1f2adf014e493cb65d0f18b3a74a98bfcd97c3d983"}, {file = "jsonlines-2.0.0.tar.gz", hash = "sha256:6fdd03104c9a421a1ba587a121aaac743bf02d8f87fa9cdaa3b852249a241fe8"}, ] -jsonschema = [ - {file = "jsonschema-3.1.1-py2.py3-none-any.whl", hash = "sha256:94c0a13b4a0616458b42529091624e66700a17f847453e52279e35509a5b7631"}, - {file = 
"jsonschema-3.1.1.tar.gz", hash = "sha256:2fa0684276b6333ff3c0b1b27081f4b2305f0a36cf702a23db50edb141893c3f"}, -] +jsonschema = [] leather = [ {file = "leather-0.3.4-py2.py3-none-any.whl", hash = "sha256:5e741daee96e9f1e9e06081b8c8a10c4ac199301a0564cdd99b09df15b4603d2"}, {file = "leather-0.3.4.tar.gz", hash = "sha256:b43e21c8fa46b2679de8449f4d953c06418666dc058ce41055ee8a8d3bb40918"}, @@ -1917,7 +1770,10 @@ mypy-extensions = [ {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] natsort = [] -networkx = [] +networkx = [ + {file = "networkx-2.8.3-py3-none-any.whl", hash = "sha256:f151edac6f9b0cf11fecce93e236ac22b499bb9ff8d6f8393b9fef5ad09506cc"}, + {file = "networkx-2.8.3.tar.gz", hash = "sha256:67fab04a955a73eb660fe7bf281b6fa71a003bc6e23a92d2f6227654c5223dbe"}, +] numpy = [] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, @@ -1927,6 +1783,7 @@ parsedatetime = [ {file = "parsedatetime-2.4-py2-none-any.whl", hash = "sha256:9ee3529454bf35c40a77115f5a596771e59e1aee8c53306f346c461b8e913094"}, {file = "parsedatetime-2.4.tar.gz", hash = "sha256:3d817c58fb9570d1eec1dd46fa9448cd644eeed4fb612684b02dfda3a79cb84b"}, ] +pathvalidate = [] pbr = [] pendulum = [ {file = "pendulum-2.1.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:b6c352f4bd32dff1ea7066bd31ad0f71f8d8100b9ff709fb343f3b86cee43efe"}, @@ -2067,29 +1924,7 @@ pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, ] -pyrsistent = [ - {file = "pyrsistent-0.18.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:df46c854f490f81210870e509818b729db4488e1f30f2a1ce1698b2295a878d1"}, - {file = "pyrsistent-0.18.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d45866ececf4a5fff8742c25722da6d4c9e180daa7b405dc0a2a2790d668c26"}, - {file = "pyrsistent-0.18.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ed6784ceac462a7d6fcb7e9b663e93b9a6fb373b7f43594f9ff68875788e01e"}, - {file = "pyrsistent-0.18.1-cp310-cp310-win32.whl", hash = "sha256:e4f3149fd5eb9b285d6bfb54d2e5173f6a116fe19172686797c056672689daf6"}, - {file = "pyrsistent-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:636ce2dc235046ccd3d8c56a7ad54e99d5c1cd0ef07d9ae847306c91d11b5fec"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e92a52c166426efbe0d1ec1332ee9119b6d32fc1f0bbfd55d5c1088070e7fc1b"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7a096646eab884bf8bed965bad63ea327e0d0c38989fc83c5ea7b8a87037bfc"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cdfd2c361b8a8e5d9499b9082b501c452ade8bbf42aef97ea04854f4a3f43b22"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-win32.whl", hash = "sha256:7ec335fc998faa4febe75cc5268a9eac0478b3f681602c1f27befaf2a1abe1d8"}, - {file = "pyrsistent-0.18.1-cp37-cp37m-win_amd64.whl", hash = "sha256:6455fc599df93d1f60e1c5c4fe471499f08d190d57eca040c0ea182301321286"}, - {file = "pyrsistent-0.18.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fd8da6d0124efa2f67d86fa70c851022f87c98e205f0594e1fae044e7119a5a6"}, - {file = 
"pyrsistent-0.18.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bfe2388663fd18bd8ce7db2c91c7400bf3e1a9e8bd7d63bf7e77d39051b85ec"}, - {file = "pyrsistent-0.18.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e3e1fcc45199df76053026a51cc59ab2ea3fc7c094c6627e93b7b44cdae2c8c"}, - {file = "pyrsistent-0.18.1-cp38-cp38-win32.whl", hash = "sha256:b568f35ad53a7b07ed9b1b2bae09eb15cdd671a5ba5d2c66caee40dbf91c68ca"}, - {file = "pyrsistent-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1b96547410f76078eaf66d282ddca2e4baae8964364abb4f4dcdde855cd123a"}, - {file = "pyrsistent-0.18.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f87cc2863ef33c709e237d4b5f4502a62a00fab450c9e020892e8e2ede5847f5"}, - {file = "pyrsistent-0.18.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bc66318fb7ee012071b2792024564973ecc80e9522842eb4e17743604b5e045"}, - {file = "pyrsistent-0.18.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:914474c9f1d93080338ace89cb2acee74f4f666fb0424896fcfb8d86058bf17c"}, - {file = "pyrsistent-0.18.1-cp39-cp39-win32.whl", hash = "sha256:1b34eedd6812bf4d33814fca1b66005805d3640ce53140ab8bbb1e2651b0d9bc"}, - {file = "pyrsistent-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:e24a828f57e0c337c8d8bb9f6b12f09dfdf0273da25fda9e314f0b684b415a07"}, - {file = "pyrsistent-0.18.1.tar.gz", hash = "sha256:d4d61f8b993a7255ba714df3aca52700f8125289f84f704cf80916517c46eb96"}, -] +pyrsistent = [] pytest = [ {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, @@ -2158,16 +1993,11 @@ pyyaml = [ {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, ] -randomname = [] requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, ] rsa = [] -s3transfer = [ - {file = "s3transfer-0.6.0-py3-none-any.whl", hash = "sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd"}, - {file = "s3transfer-0.6.0.tar.gz", hash = "sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947"}, -] semver = [ {file = "semver-2.13.0-py2.py3-none-any.whl", hash = "sha256:ced8b23dceb22134307c1b8abfa523da14198793d9787ac838e70e29e77458d4"}, {file = "semver-2.13.0.tar.gz", hash = "sha256:fa0fe2722ee1c3f57eac478820c3a5ae2f624af8264cbdf9000c980ff7f75e3f"}, @@ -2244,12 +2074,8 @@ smmap = [ {file = "smmap-5.0.0-py3-none-any.whl", hash = "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94"}, {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"}, ] -sqlparse = [ - {file = "sqlparse-0.4.2-py3-none-any.whl", hash = "sha256:48719e356bb8b42991bdbb1e8b83223757b93789c00910a616a071910ca4a64d"}, - {file = "sqlparse-0.4.2.tar.gz", hash = "sha256:0c00730c74263a94e5a9919ade150dfc3b19c574389985446148402998287dae"}, -] +sqlparse = [] stevedore = [] -termcolor = [] text-unidecode = [ {file = "text-unidecode-1.3.tar.gz", hash = 
"sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, @@ -2273,11 +2099,7 @@ types-pyyaml = [] types-requests = [] types-simplejson = [] types-urllib3 = [] -typing-extensions = [ - {file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"}, - {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, - {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"}, -] +typing-extensions = [] tzdata = [] urllib3 = [] werkzeug = [ diff --git a/pyproject.toml b/pyproject.toml index 83181dc616..467552a128 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,18 +47,18 @@ google-cloud-bigquery-storage = {version = "^2.13.0", optional = true} pyarrow = {version = "^8.0.0", optional = true} GitPython = {version = "^3.1.26", optional = true} -dbt-core = {version = "1.0.6", optional = true} -dbt-redshift = {version = "1.0.1", optional = true} -dbt-bigquery = {version = "1.0.0", optional = true} -randomname = "^0.1.5" +dbt-core = {version = ">=1.1.0,<1.2.0", optional = true} +dbt-redshift = {version = ">=1.0.0,<1.2.0", optional = true} +dbt-bigquery = {version = ">=1.0.0,<1.2.0", optional = true} tzdata = "^2022.1" tomlkit = "^0.11.3" asyncstdlib = "^3.10.5" +pathvalidate = "^2.5.2" [tool.poetry.dev-dependencies] pytest = "^6.2.4" -mypy = "0.971" +mypy = "0.982" flake8 = "^5.0.0" bandit = "^1.7.0" flake8-bugbear = "^22.0.0" @@ -75,6 +75,7 @@ types-python-dateutil = "^2.8.15" flake8-tidy-imports = "^4.8.0" flake8-encodings = "^0.5.0" flake8-builtins = "^1.5.3" +typing-extensions = "^4.4.0" [tool.poetry.extras] dbt = ["dbt-core", "GitPython", "dbt-redshift", "dbt-bigquery"] diff --git a/tests/.example.env b/tests/.example.env index 0a9a700dcf..c38ab0530d 100644 --- a/tests/.example.env +++ b/tests/.example.env @@ -4,14 +4,14 @@ DEFAULT_DATASET=carbon_bot_3 -GCP__PROJECT_ID=chat-analytics-317513 -GCP__PRIVATE_KEY="-----BEGIN PRIVATE KEY----- +CREDENTIALS__PROJECT_ID=chat-analytics-317513 +CREDENTIALS__PRIVATE_KEY="-----BEGIN PRIVATE KEY----- paste key here -----END PRIVATE KEY----- " -CLIENT_EMAIL=loader@chat-analytics-317513.iam.gserviceaccount.com +CREDENTIALS__CLIENT_EMAIL=loader@chat-analytics-317513.iam.gserviceaccount.com -PG__DBNAME=chat_analytics_rasa -PG__USER=loader -PG__HOST=3.73.90.3 -PG__PASSWORD=set-me-up \ No newline at end of file +CREDENTIALS__DBNAME=chat_analytics_rasa +CREDENTIALS__USER=loader +CREDENTIALS__HOST=3.73.90.3 +CREDENTIALS__PASSWORD=set-me-up \ No newline at end of file diff --git a/tests/cases.py b/tests/cases.py index b17f666d0c..8da604528b 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -15,7 +15,7 @@ "big_decimal": Decimal("115792089237316195423570985008687907853269984665640564039457584007913129639935.1"), "datetime": pendulum.parse("2005-04-02T20:37:37.358236Z"), "date": pendulum.parse("2022-02-02").date(), - "uuid": UUID(_UUID), + # "uuid": UUID(_UUID), "hexbytes": HexBytes("0x2137"), "bytes": b'2137', "wei": Wei.from_int256(2137, decimals=2) @@ -26,8 +26,8 @@ "decimal": "decimal", "big_decimal": "decimal", "datetime": "timestamp", - "date": "text", - "uuid": "text", + "date": "timestamp", + # "uuid": "text", "hexbytes": "binary", "bytes": "binary", "wei": "wei" diff --git 
a/tests/common/cases/configuration/.dlt/config.toml b/tests/common/cases/configuration/.dlt/config.toml new file mode 100644 index 0000000000..13e287065f --- /dev/null +++ b/tests/common/cases/configuration/.dlt/config.toml @@ -0,0 +1,27 @@ +api_type="REST" + +api.url="http" +api.port=1024 + +[api.params] +param1="a" +param2="b" + +[typecheck] +str_val="test string" +int_val=12345 +bool_val=true +list_val=[1, "2", [3]] +dict_val={'a'=1, "b"="2"} +float_val=1.18927 +tuple_val=[1, 2, {1="complicated dicts allowed in literal eval"}] +COMPLEX_VAL={"_"= [1440, ["*"], []], "change-email"= [560, ["*"], []]} +date_val=1979-05-27T07:32:00-08:00 +dec_val="22.38" # always use text to pass decimals +bytes_val="0x48656c6c6f20576f726c6421" # always use text to pass hex value that should be converted to bytes +any_val="function() {}" +none_val="none" +sequence_val=["A", "B", "KAPPA"] +gen_list_val=["C", "Z", "N"] +mapping_val={"FL"=1, "FR"={"1"=2}} +mutable_mapping_val={"str"="str"} diff --git a/tests/common/cases/configuration/.dlt/secrets.toml b/tests/common/cases/configuration/.dlt/secrets.toml new file mode 100644 index 0000000000..42e11d46dd --- /dev/null +++ b/tests/common/cases/configuration/.dlt/secrets.toml @@ -0,0 +1,70 @@ +secret_value="2137" +api.port=1023 + +# holds a literal string that can be parsed as gcp credentials +source.credentials=''' +{ + "type": "service_account", + "project_id": "mock-project-id-source.credentials", + "private_key_id": "62c1f8f00836dec27c8d96d1c0836df2c1f6bce4", + "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n...\n-----END PRIVATE KEY-----\n", + "client_email": "loader@a7513.iam.gserviceaccount.com", + "client_id": "114701312674477307596", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-2.iam.gserviceaccount.com", + "file_upload_timeout": 819872989 + } +''' + +[credentials] +secret_value="2137" +"project_id"="mock-project-id-credentials" + +[gcp_storage] +"project_id"="mock-project-id-gcp-storage" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" + +[destination.redshift.credentials] +dbname="destination.redshift.credentials" +user="loader" +host="3.73.90.3" +password="set-me-up" + +[destination.credentials] +"type"="service_account" +"project_id"="mock-project-id-destination.credentials" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" +"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" + +[destination.bigquery] +"type"="service_account" +"project_id"="mock-project-id-destination.bigquery" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE 
KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" +"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" + +[destination.bigquery.credentials] +"type"="service_account" +"project_id"="mock-project-id-destination.bigquery.credentials" +"private_key_id"="62c1f8f00836dec27c8d96d1c0836df2c1f6bce4" +"private_key"="-----BEGIN PRIVATE KEY-----\nMIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCNEN0bL39HmD+S\n....\n-----END PRIVATE KEY-----\n" +"client_email"="loader@a7513.iam.gserviceaccount.com" +"client_id"="114701312674477307596" +"auth_uri"="https://accounts.google.com/o/oauth2/auth" +"token_uri"="https://oauth2.googleapis.com/token" +"auth_provider_x509_cert_url"="https://www.googleapis.com/oauth2/v1/certs" +"client_x509_cert_url"="https://www.googleapis.com/robot/v1/metadata/x509/loader%40mock-project-id-1.iam.gserviceaccount.com" diff --git a/tests/common/cases/schemas/ev1/event_schema.7z b/tests/common/cases/schemas/ev1/event.schema.7z similarity index 100% rename from tests/common/cases/schemas/ev1/event_schema.7z rename to tests/common/cases/schemas/ev1/event.schema.7z diff --git a/tests/common/cases/schemas/ev1/event_schema.json b/tests/common/cases/schemas/ev1/event.schema.json similarity index 100% rename from tests/common/cases/schemas/ev1/event_schema.json rename to tests/common/cases/schemas/ev1/event.schema.json diff --git a/tests/common/cases/schemas/ev1/model_schema.json b/tests/common/cases/schemas/ev1/model.schema.json similarity index 100% rename from tests/common/cases/schemas/ev1/model_schema.json rename to tests/common/cases/schemas/ev1/model.schema.json diff --git a/tests/common/cases/schemas/ev2/event_schema.json b/tests/common/cases/schemas/ev2/event.schema.json similarity index 100% rename from tests/common/cases/schemas/ev2/event_schema.json rename to tests/common/cases/schemas/ev2/event.schema.json diff --git a/tests/common/cases/schemas/rasa/event_schema.json b/tests/common/cases/schemas/rasa/event.schema.json similarity index 100% rename from tests/common/cases/schemas/rasa/event_schema.json rename to tests/common/cases/schemas/rasa/event.schema.json diff --git a/tests/common/cases/schemas/rasa/model_schema.json b/tests/common/cases/schemas/rasa/model.schema.json similarity index 100% rename from tests/common/cases/schemas/rasa/model_schema.json rename to tests/common/cases/schemas/rasa/model.schema.json diff --git a/tests/common/configuration/__init__.py b/tests/common/configuration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py new file mode 100644 index 0000000000..711bf4ee83 --- /dev/null +++ b/tests/common/configuration/test_configuration.py @@ -0,0 +1,651 @@ +import pytest +import datetime # noqa: I251 +from typing import Any, Dict, List, Mapping, MutableMapping, NewType, Optional, Type + +from dlt.common import pendulum, Decimal, Wei +from dlt.common.utils import custom_environ +from dlt.common.typing import TSecretValue, extract_inner_type +from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, 
ConfigFieldTypeHintNotSupported, InvalidInitialValue, LookupTrace, ValueNotSecretException +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigValueCannotBeCoercedException, resolve +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration +from dlt.common.configuration.specs.base_configuration import is_valid_hint +from dlt.common.configuration.providers import environ as environ_provider, toml + +from tests.utils import preserve_environ, add_config_dict_to_env +from tests.common.configuration.utils import MockProvider, CoercionTestConfiguration, COERCIONS, WithCredentialsConfiguration, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider + +INVALID_COERCIONS = { + # 'STR_VAL': 'test string', # string always OK + 'int_val': "a12345", + 'bool_val': "not_bool", # bool overridden by string - that is the most common problem + 'list_val': {2: 1, "2": 3.0}, + 'dict_val': "{'a': 1, 'b', '2'}", + 'bytes_val': 'Hello World!', + 'float_val': "invalid", + "tuple_val": "{1:2}", + "date_val": "01 May 2022", + "dec_val": True +} + +EXCEPTED_COERCIONS = { + # allows to use int for float + 'float_val': 10, + # allows to use float for str + 'str_val': 10.0 +} + +COERCED_EXCEPTIONS = { + # allows to use int for float + 'float_val': 10.0, + # allows to use float for str + 'str_val': "10.0" +} + + +@configspec +class VeryWrongConfiguration(WrongConfiguration): + pipeline_name: str = "Some Name" + str_val: str = "" + int_val: int = None + log_color: str = "1" # type: ignore + + +@configspec +class ConfigurationWithOptionalTypes(RunConfiguration): + pipeline_name: str = "Some Name" + + str_val: Optional[str] = None + int_val: Optional[int] = None + bool_val: bool = True + + +@configspec +class ProdConfigurationWithOptionalTypes(ConfigurationWithOptionalTypes): + prod_val: str = "prod" + + +@configspec +class MockProdConfiguration(RunConfiguration): + pipeline_name: str = "comp" + + +@configspec(init=True) +class FieldWithNoDefaultConfiguration(RunConfiguration): + no_default: str + + +@configspec(init=True) +class InstrumentedConfiguration(BaseConfiguration): + head: str + tube: List[str] + heels: str + + def to_native_representation(self) -> Any: + return self.head + ">" + ">".join(self.tube) + ">" + self.heels + + def from_native_representation(self, native_value: Any) -> None: + if not isinstance(native_value, str): + raise ValueError(native_value) + parts = native_value.split(">") + self.head = parts[0] + self.heels = parts[-1] + self.tube = parts[1:-1] + + def check_integrity(self) -> None: + if self.head > self.heels: + raise RuntimeError("Head over heels") + + +@configspec +class EmbeddedConfiguration(BaseConfiguration): + default: str + instrumented: InstrumentedConfiguration + namespaced: NamespacedConfiguration + + +@configspec +class EmbeddedOptionalConfiguration(BaseConfiguration): + instrumented: Optional[InstrumentedConfiguration] + + +@configspec +class EmbeddedSecretConfiguration(BaseConfiguration): + secret: SecretConfiguration + + +LongInteger = NewType("LongInteger", int) +FirstOrderStr = NewType("FirstOrderStr", str) +SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) + + +def test_initial_config_state() -> None: + assert BaseConfiguration.__is_resolved__ is False + assert BaseConfiguration.__namespace__ is None + c = BaseConfiguration() + assert c.__is_resolved__ is False + assert c.is_resolved() is False + # base configuration has no resolvable fields so is never partial + assert 
c.is_partial() is False + + +def test_set_initial_config_value(environment: Any) -> None: + # set from init method + c = resolve.resolve_configuration(InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he")) + assert c.to_native_representation() == "h>a>b>he" + # set from native form + c = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="h>a>b>he") + assert c.head == "h" + assert c.tube == ["a", "b"] + assert c.heels == "he" + # set from dictionary + c = resolve.resolve_configuration(InstrumentedConfiguration(), initial_value={"head": "h", "tube": ["tu", "be"], "heels": "xhe"}) + assert c.to_native_representation() == "h>tu>be>xhe" + + +def test_initial_native_representation_skips_resolve(environment: Any) -> None: + c = InstrumentedConfiguration() + # mock namespace to enable looking for initials in provider + c.__namespace__ = "ins" + # explicit initial does not skip resolve + environment["INS__HEELS"] = "xhe" + c = resolve.resolve_configuration(c, initial_value="h>a>b>he") + assert c.heels == "xhe" + + # now put the whole native representation in env + environment["INS"] = "h>a>b>he" + c = InstrumentedConfiguration() + c.__namespace__ = "ins" + c = resolve.resolve_configuration(c, initial_value="h>a>b>uhe") + assert c.heels == "he" + + +def test_query_initial_config_value_if_config_namespace(environment: Any) -> None: + c = InstrumentedConfiguration(head="h", tube=["a", "b"], heels="he") + # mock the __namespace__ to enable the query + c.__namespace__ = "snake" + # provide the initial value + environment["SNAKE"] = "h>tu>be>xhe" + c = resolve.resolve_configuration(c) + # check if the initial value loaded + assert c.heels == "xhe" + + +def test_invalid_initial_config_value() -> None: + # 2137 cannot be parsed and also is not a dict that can initialize the fields + with pytest.raises(InvalidInitialValue) as py_ex: + resolve.resolve_configuration(InstrumentedConfiguration(), initial_value=2137) + assert py_ex.value.spec is InstrumentedConfiguration + assert py_ex.value.initial_value_type is int + + +def test_check_integrity(environment: Any) -> None: + with pytest.raises(RuntimeError): + # head over hells + resolve.resolve_configuration(InstrumentedConfiguration(), initial_value="he>a>b>h") + + +def test_embedded_config(environment: Any) -> None: + # resolve all embedded config, using initial value for instrumented config and initial dict for namespaced config + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "instrumented": "h>tu>be>xhe", "namespaced": {"password": "pwd"}}) + assert C.default == "set" + assert C.instrumented.to_native_representation() == "h>tu>be>xhe" + assert C.namespaced.password == "pwd" + + # resolve but providing values via env + with custom_environ({"INSTRUMENTED": "h>tu>u>be>xhe", "NAMESPACED__PASSWORD": "passwd", "DEFAULT": "DEF"}): + C = resolve.resolve_configuration(EmbeddedConfiguration()) + assert C.default == "DEF" + assert C.instrumented.to_native_representation() == "h>tu>u>be>xhe" + assert C.namespaced.password == "passwd" + + # resolve partial, partial is passed to embedded + C = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) + assert not C.__is_resolved__ + assert not C.namespaced.__is_resolved__ + assert not C.instrumented.__is_resolved__ + + # some are partial, some are not + with custom_environ({"NAMESPACED__PASSWORD": "passwd"}): + C = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) + assert not C.__is_resolved__ + 
assert C.namespaced.__is_resolved__ + assert not C.instrumented.__is_resolved__ + + # single integrity error fails all the embeds + with custom_environ({"INSTRUMENTED": "he>tu>u>be>h"}): + with pytest.raises(RuntimeError): + resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + + # part via env part via initial values + with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"default": "set", "namespaced": {"password": "pwd"}}) + assert C.instrumented.to_native_representation() == "h>tu>u>be>he" + + +def test_provider_values_over_initial(environment: Any) -> None: + with custom_environ({"INSTRUMENTED": "h>tu>u>be>he"}): + C = resolve.resolve_configuration(EmbeddedConfiguration(), initial_value={"instrumented": "h>tu>be>xhe"}, accept_partial=True) + assert C.instrumented.to_native_representation() == "h>tu>u>be>he" + # parent configuration is not resolved + assert not C.is_resolved() + assert C.is_partial() + # but embedded is + assert C.instrumented.__is_resolved__ + assert C.instrumented.is_resolved() + assert not C.instrumented.is_partial() + + +def test_run_configuration_gen_name(environment: Any) -> None: + C = resolve.resolve_configuration(RunConfiguration()) + assert C.pipeline_name.startswith("dlt_") + + +def test_configuration_is_mutable_mapping(environment: Any) -> None: + + + @configspec + class _SecretCredentials(RunConfiguration): + pipeline_name: Optional[str] = "secret" + secret_value: TSecretValue = None + + + # configurations provide full MutableMapping support + # here order of items in dict matters + expected_dict = { + 'pipeline_name': 'secret', + 'sentry_dsn': None, + 'prometheus_port': None, + 'log_format': '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}', + 'log_level': 'DEBUG', + 'request_timeout': (15, 300), + 'config_files_storage_path': '_storage/config/%s', + "secret_value": None + } + assert dict(_SecretCredentials()) == expected_dict + + environment["SECRET_VALUE"] = "secret" + c = resolve.resolve_configuration(_SecretCredentials()) + expected_dict["secret_value"] = "secret" + assert dict(c) == expected_dict + + # check mutable mapping type + assert isinstance(c, MutableMapping) + assert isinstance(c, Mapping) + assert not isinstance(c, Dict) + + # check view ops + assert c.keys() == expected_dict.keys() + assert len(c) == len(expected_dict) + assert c.items() == expected_dict.items() + assert list(c.values()) == list(expected_dict.values()) + for key in c: + assert c[key] == expected_dict[key] + # version is present as attr but not present in dict + assert hasattr(c, "__is_resolved__") + assert hasattr(c, "__namespace__") + + # set ops + # update supported and non existing attributes are ignored + c.update({"pipeline_name": "old pipe", "__version": "1.1.1"}) + assert c.pipeline_name == "old pipe" == c["pipeline_name"] + + # delete is not supported + with pytest.raises(KeyError): + del c["pipeline_name"] + + with pytest.raises(KeyError): + c.pop("pipeline_name", None) + + # setting supported + c["pipeline_name"] = "new pipe" + assert c.pipeline_name == "new pipe" == c["pipeline_name"] + with pytest.raises(KeyError): + c["unknown_prop"] = "unk" + + # also on new instance + c = SecretConfiguration() + with pytest.raises(KeyError): + c["unknown_prop"] = "unk" + + +def test_fields_with_no_default_to_null(environment: Any) -> None: + # fields with no default are promoted to class 
attrs with none + assert FieldWithNoDefaultConfiguration.no_default is None + assert FieldWithNoDefaultConfiguration().no_default is None + + +def test_init_method_gen(environment: Any) -> None: + C = FieldWithNoDefaultConfiguration(no_default="no_default", sentry_dsn="SENTRY") + assert C.no_default == "no_default" + assert C.sentry_dsn == "SENTRY" + + +def test_multi_derivation_defaults(environment: Any) -> None: + + @configspec + class MultiConfiguration(MockProdConfiguration, ConfigurationWithOptionalTypes, NamespacedConfiguration): + pass + + # apparently dataclasses set default in reverse mro so MockProdConfiguration overwrites + C = MultiConfiguration() + assert C.pipeline_name == MultiConfiguration.pipeline_name == "comp" + # but keys are ordered in MRO so password from NamespacedConfiguration goes first + keys = list(C.keys()) + assert keys[0] == "password" + assert keys[-1] == "bool_val" + assert C.__namespace__ == "DLT_TEST" + + +def test_raises_on_unresolved_field(environment: Any) -> None: + # via make configuration + with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: + resolve.resolve_configuration(WrongConfiguration()) + assert cf_missing_exc.value.spec_name == "WrongConfiguration" + assert "NoneConfigVar" in cf_missing_exc.value.traces + # has only one trace + trace = cf_missing_exc.value.traces["NoneConfigVar"] + assert len(trace) == 3 + assert trace[0] == LookupTrace("Environment Variables", [], "NONECONFIGVAR", None) + assert trace[1] == LookupTrace("Pipeline secrets.toml", [], "NoneConfigVar", None) + assert trace[2] == LookupTrace("Pipeline config.toml", [], "NoneConfigVar", None) + + +def test_raises_on_many_unresolved_fields(environment: Any) -> None: + # via make configuration + with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: + resolve.resolve_configuration(CoercionTestConfiguration()) + assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" + # get all fields that must be set + val_fields = [f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val")] + traces = cf_missing_exc.value.traces + assert len(traces) == len(val_fields) + for tr_field, exp_field in zip(traces, val_fields): + assert len(traces[tr_field]) == 3 + assert traces[tr_field][0] == LookupTrace("Environment Variables", [], environ_provider.EnvironProvider.get_key_name(exp_field), None) + assert traces[tr_field][1] == LookupTrace("Pipeline secrets.toml", [], toml.TomlProvider.get_key_name(exp_field), None) + assert traces[tr_field][2] == LookupTrace("Pipeline config.toml", [], toml.TomlProvider.get_key_name(exp_field), None) + + +def test_accepts_optional_missing_fields(environment: Any) -> None: + # ConfigurationWithOptionalTypes has values for all non optional fields present + C = ConfigurationWithOptionalTypes() + assert not C.is_partial() + # make optional config + resolve.resolve_configuration(ConfigurationWithOptionalTypes()) + # make config with optional values + resolve.resolve_configuration(ProdConfigurationWithOptionalTypes(), initial_value={"int_val": None}) + # make config with optional embedded config + C = resolve.resolve_configuration(EmbeddedOptionalConfiguration()) + # embedded config was not fully resolved + assert not C.instrumented.__is_resolved__ + assert not C.instrumented.is_resolved() + assert C.instrumented.is_partial() + + +def test_find_all_keys() -> None: + keys = VeryWrongConfiguration().get_resolvable_fields() + # assert hints and types: LOG_COLOR had it hint overwritten in derived class + 
assert set({'str_val': str, 'int_val': int, 'NoneConfigVar': str, 'log_color': str}.items()).issubset(keys.items()) + + +def test_coercion_to_hint_types(environment: Any) -> None: + add_config_dict_to_env(COERCIONS) + + C = CoercionTestConfiguration() + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) + + for key in COERCIONS: + assert getattr(C, key) == COERCIONS[key] + + +def test_values_serialization() -> None: + # test tuple + t_tuple = (1, 2, 3, "A") + v = resolve.serialize_value(t_tuple) + assert v == "(1, 2, 3, 'A')" # literal serialization + assert resolve.deserialize_value("K", v, tuple) == t_tuple + + # test list + t_list = ["a", 3, True] + v = resolve.serialize_value(t_list) + assert v == '["a", 3, true]' # json serialization + assert resolve.deserialize_value("K", v, list) == t_list + + # test datetime + t_date = pendulum.now() + v = resolve.serialize_value(t_date) + assert resolve.deserialize_value("K", v, datetime.datetime) == t_date + + # test wei + t_wei = Wei.from_int256(10**16, decimals=18) + v = resolve.serialize_value(t_wei) + assert v == "0.01" + # can be deserialized into + assert resolve.deserialize_value("K", v, float) == 0.01 + assert resolve.deserialize_value("K", v, Decimal) == Decimal("0.01") + assert resolve.deserialize_value("K", v, Wei) == Wei("0.01") + + +def test_invalid_coercions(environment: Any) -> None: + C = CoercionTestConfiguration() + add_config_dict_to_env(INVALID_COERCIONS) + for key, value in INVALID_COERCIONS.items(): + try: + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) + except ConfigValueCannotBeCoercedException as coerc_exc: + # must fail exactly on expected value + if coerc_exc.field_name != key: + raise + # overwrite with valid value and go to next env + environment[key.upper()] = resolve.serialize_value(COERCIONS[key]) + continue + raise AssertionError("%s was coerced with %s which is invalid type" % (key, value)) + + +def test_excepted_coercions(environment: Any) -> None: + C = CoercionTestConfiguration() + add_config_dict_to_env(COERCIONS) + add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) + for key in EXCEPTED_COERCIONS: + assert getattr(C, key) == COERCED_EXCEPTIONS[key] + + +def test_config_with_unsupported_types_in_hints(environment: Any) -> None: + with pytest.raises(ConfigFieldTypeHintNotSupported): + + @configspec + class InvalidHintConfiguration(BaseConfiguration): + tuple_val: tuple = None # type: ignore + set_val: set = None # type: ignore + InvalidHintConfiguration() + + +def test_config_with_no_hints(environment: Any) -> None: + with pytest.raises(ConfigFieldMissingTypeHintException): + + @configspec + class NoHintConfiguration(BaseConfiguration): + tuple_val = None + NoHintConfiguration() + + +def test_resolve_configuration(environment: Any) -> None: + # fill up configuration + environment["NONECONFIGVAR"] = "1" + C = resolve.resolve_configuration(WrongConfiguration()) + assert C.__is_resolved__ + assert C.NoneConfigVar == "1" + + +def test_dataclass_instantiation(environment: Any) -> None: + # resolve_configuration works on instances of dataclasses and types are not modified + environment['SECRET_VALUE'] = "1" + C = resolve.resolve_configuration(SecretConfiguration()) + # auto derived type holds the value + assert C.secret_value == "1" + # base type is untouched + assert 
SecretConfiguration.secret_value is None + + +def test_initial_values(environment: Any) -> None: + # initial values will be overridden from env + environment["PIPELINE_NAME"] = "env name" + environment["CREATED_VAL"] = "12837" + # set initial values and allow partial config + C = resolve.resolve_configuration(CoercionTestConfiguration(), + initial_value={"pipeline_name": "initial name", "none_val": type(environment), "created_val": 878232, "bytes_val": b"str"}, + accept_partial=True + ) + # from env + assert C.pipeline_name == "env name" + # from initial + assert C.bytes_val == b"str" + assert C.none_val == type(environment) + # new prop overridden from env + assert environment["CREATED_VAL"] == "12837" + + +def test_accept_partial(environment: Any) -> None: + # modify original type + WrongConfiguration.NoneConfigVar = None + # that None value will be present in the instance + C = resolve.resolve_configuration(WrongConfiguration(), accept_partial=True) + assert C.NoneConfigVar is None + # partial resolution + assert not C.__is_resolved__ + assert C.is_partial() + + +def test_coercion_rules() -> None: + with pytest.raises(ConfigValueCannotBeCoercedException): + coerce_single_value("key", "some string", int) + assert coerce_single_value("key", "some string", str) == "some string" + # Optional[str] has type object, mypy will never work properly... + assert coerce_single_value("key", "some string", Optional[str]) == "some string" # type: ignore + + assert coerce_single_value("key", "234", int) == 234 + assert coerce_single_value("key", "234", Optional[int]) == 234 # type: ignore + + # check coercions of NewTypes + assert coerce_single_value("key", "test str X", FirstOrderStr) == "test str X" + assert coerce_single_value("key", "test str X", Optional[FirstOrderStr]) == "test str X" # type: ignore + assert coerce_single_value("key", "test str X", Optional[SecondOrderStr]) == "test str X" # type: ignore + assert coerce_single_value("key", "test str X", SecondOrderStr) == "test str X" + assert coerce_single_value("key", "234", LongInteger) == 234 + assert coerce_single_value("key", "234", Optional[LongInteger]) == 234 # type: ignore + # this coercion should fail + with pytest.raises(ConfigValueCannotBeCoercedException): + coerce_single_value("key", "some string", LongInteger) + with pytest.raises(ConfigValueCannotBeCoercedException): + coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore + + +def test_is_valid_hint() -> None: + assert is_valid_hint(Any) is True + assert is_valid_hint(Optional[Any]) is True + assert is_valid_hint(RunConfiguration) is True + assert is_valid_hint(Optional[RunConfiguration]) is True + assert is_valid_hint(TSecretValue) is True + assert is_valid_hint(Optional[TSecretValue]) is True + # in case of generics, origin will be used and args are not checked + assert is_valid_hint(MutableMapping[TSecretValue, Any]) is True + # this is valid (args not checked) + assert is_valid_hint(MutableMapping[TSecretValue, ConfigValueCannotBeCoercedException]) is True + assert is_valid_hint(Wei) is True + # any class type, except deriving from BaseConfiguration is wrong type + assert is_valid_hint(ConfigFieldMissingException) is False + + +def test_configspec_auto_base_config_derivation() -> None: + + @configspec(init=True) + class AutoBaseDerivationConfiguration: + auto: str + + assert issubclass(AutoBaseDerivationConfiguration, BaseConfiguration) + assert hasattr(AutoBaseDerivationConfiguration, "auto") + + assert AutoBaseDerivationConfiguration().auto is None 
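
For readers skimming this part of the diff: the new tests above exercise dlt's internal configuration layer. A minimal, hedged sketch of the pattern they rely on follows, a @configspec class resolved from environment variables with type coercion. ConnectionConfiguration, HOST and PORT are illustrative names invented for the sketch, and it assumes the dlt revision introduced by this change rather than any published API.

import os
from typing import Optional

from dlt.common.configuration import configspec, resolve
from dlt.common.configuration.specs import BaseConfiguration


@configspec
class ConnectionConfiguration(BaseConfiguration):
    host: str = None                 # non-Optional with no value: resolution fails if no provider supplies it
    port: int = 5432                 # default, may be overridden by a provider
    password: Optional[str] = None   # Optional: allowed to stay None

# the environ provider looks plain (un-namespaced) fields up by their upper-cased names
os.environ["HOST"] = "localhost"
os.environ["PORT"] = "6432"

cfg = resolve.resolve_configuration(ConnectionConfiguration())
assert cfg.host == "localhost"
assert cfg.port == 6432              # the string "6432" is coerced to the int hint
assert cfg.is_resolved()
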
+ assert AutoBaseDerivationConfiguration(auto="auto").auto == "auto" + assert AutoBaseDerivationConfiguration(auto="auto").get_resolvable_fields() == {"auto": str} + # we preserve original module + assert AutoBaseDerivationConfiguration.__module__ == __name__ + assert not hasattr(BaseConfiguration, "auto") + + +def test_secret_value_not_secret_provider(mock_provider: MockProvider) -> None: + mock_provider.value = "SECRET" + + # TSecretValue will fail + with pytest.raises(ValueNotSecretException) as py_ex: + resolve.resolve_configuration(SecretConfiguration(), namespaces=("mock",)) + assert py_ex.value.provider_name == "Mock Provider" + assert py_ex.value.key == "-secret_value" + + # anything derived from CredentialsConfiguration will fail + with pytest.raises(ValueNotSecretException) as py_ex: + resolve.resolve_configuration(WithCredentialsConfiguration(), namespaces=("mock",)) + assert py_ex.value.provider_name == "Mock Provider" + assert py_ex.value.key == "-credentials" + + +def test_do_not_resolve_twice(environment: Any) -> None: + environment["SECRET_VALUE"] = "password" + c = resolve.resolve_configuration(SecretConfiguration()) + assert c.secret_value == "password" + c2 = SecretConfiguration() + c2.secret_value = "other" + c2.__is_resolved__ = True + assert c2.is_resolved() + # will not overwrite with env + c3 = resolve.resolve_configuration(c2) + assert c3.secret_value == "other" + assert c3 is c2 + # make it not resolved + c2.__is_resolved__ = False + c4 = resolve.resolve_configuration(c2) + assert c4.secret_value == "password" + assert c2 is c3 is c4 + # also c is resolved so + c.secret_value = "else" + assert resolve.resolve_configuration(c).secret_value == "else" + + +def test_do_not_resolve_embedded(environment: Any) -> None: + environment["SECRET__SECRET_VALUE"] = "password" + c = resolve.resolve_configuration(EmbeddedSecretConfiguration()) + assert c.secret.secret_value == "password" + c2 = SecretConfiguration() + c2.secret_value = "other" + c2.__is_resolved__ = True + embed_c = EmbeddedSecretConfiguration() + embed_c.secret = c2 + embed_c2 = resolve.resolve_configuration(embed_c) + assert embed_c2.secret.secret_value == "other" + assert embed_c2.secret is c2 + + +def test_last_resolve_exception(environment: Any) -> None: + # partial will set the ConfigEntryMissingException + c = resolve.resolve_configuration(EmbeddedConfiguration(), accept_partial=True) + assert isinstance(c.__exception__, ConfigFieldMissingException) + # missing keys + c = SecretConfiguration() + with pytest.raises(ConfigFieldMissingException) as py_ex: + resolve.resolve_configuration(c) + assert c.__exception__ is py_ex.value + # but if ran again exception is cleared + environment["SECRET_VALUE"] = "password" + resolve.resolve_configuration(c) + assert c.__exception__ is None + # initial value + c = InstrumentedConfiguration() + with pytest.raises(InvalidInitialValue) as py_ex: + resolve.resolve_configuration(c, initial_value=2137) + assert c.__exception__ is py_ex.value + + +def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: + hint = extract_inner_type(hint) + return resolve.deserialize_value(key, value, hint) diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py new file mode 100644 index 0000000000..45f0e29738 --- /dev/null +++ b/tests/common/configuration/test_container.py @@ -0,0 +1,150 @@ +import pytest +from typing import Any, ClassVar, Literal + +from dlt.common.configuration import configspec +from 
dlt.common.configuration.providers.container import ContextProvider +from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext +from dlt.common.configuration.container import Container +from dlt.common.configuration.exceptions import ContainerInjectableContextMangled, InvalidInitialValue, ContextDefaultCannotBeCreated +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext + +from tests.utils import preserve_environ +from tests.common.configuration.utils import environment + + +@configspec(init=True) +class InjectableTestContext(ContainerInjectableContext): + current_value: str + + def from_native_representation(self, native_value: Any) -> None: + raise ValueError(native_value) + + +@configspec +class EmbeddedWithInjectableContext(BaseConfiguration): + injected: InjectableTestContext + + +@configspec +class NoDefaultInjectableContext(ContainerInjectableContext): + + can_create_default: ClassVar[bool] = False + + +@pytest.fixture() +def container() -> Container: + # erase singleton + Container._INSTANCE = None + return Container() + + +def test_singleton(container: Container) -> None: + # keep the old configurations list + container_configurations = container.contexts + + singleton = Container() + # make sure it is the same object + assert container is singleton + # that holds the same configurations dictionary + assert container_configurations is singleton.contexts + + +def test_get_default_injectable_config(container: Container) -> None: + injectable = container[InjectableTestContext] + assert injectable.current_value is None + assert isinstance(injectable, InjectableTestContext) + + +def test_raise_on_no_default_value(container: Container) -> None: + with pytest.raises(ContextDefaultCannotBeCreated): + container[NoDefaultInjectableContext] + + # ok when injected + with container.injectable_context(NoDefaultInjectableContext()) as injected: + assert container[NoDefaultInjectableContext] is injected + + +def test_container_injectable_context(container: Container) -> None: + with container.injectable_context(InjectableTestContext()) as current_config: + assert current_config.current_value is None + current_config.current_value = "TEST" + assert container[InjectableTestContext].current_value == "TEST" + assert container[InjectableTestContext] is current_config + + assert InjectableTestContext not in container + + +def test_container_injectable_context_restore(container: Container) -> None: + # this will create InjectableTestConfiguration + original = container[InjectableTestContext] + original.current_value = "ORIGINAL" + with container.injectable_context(InjectableTestContext()) as current_config: + current_config.current_value = "TEST" + # nested context is supported + with container.injectable_context(InjectableTestContext()) as inner_config: + assert inner_config.current_value is None + assert container[InjectableTestContext] is inner_config + assert container[InjectableTestContext] is current_config + + assert container[InjectableTestContext] is original + assert container[InjectableTestContext].current_value == "ORIGINAL" + + +def test_container_injectable_context_mangled(container: Container) -> None: + original = container[InjectableTestContext] + original.current_value = "ORIGINAL" + + context = InjectableTestContext() + with pytest.raises(ContainerInjectableContextMangled) as py_ex: + with container.injectable_context(context) as current_config: + 
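
As a rough companion to the container tests above (not an authoritative reference), the sketch below shows the injection pattern they verify: a context object is injected for the duration of a with block and can be looked up by its type from anywhere. RequestContext and request_id are invented names, and the snippet assumes the same dlt revision as this diff.

from typing import Optional

from dlt.common.configuration import configspec
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs import ContainerInjectableContext


@configspec(init=True)
class RequestContext(ContainerInjectableContext):
    request_id: Optional[str] = None

container = Container()   # Container behaves as a process-wide singleton
with container.injectable_context(RequestContext(request_id="r-1")) as ctx:
    # while injected, any code can fetch the active context by type
    assert container[RequestContext] is ctx
    assert container[RequestContext].request_id == "r-1"
# on exit the injected instance is removed from the container again
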
current_config.current_value = "TEST" + # overwrite the config in container + container.contexts[InjectableTestContext] = InjectableTestContext() + assert py_ex.value.spec == InjectableTestContext + assert py_ex.value.expected_config == context + + +def test_container_provider(container: Container) -> None: + provider = ContextProvider() + # default value will be created + v, k = provider.get_value("n/a", InjectableTestContext) + assert isinstance(v, InjectableTestContext) + assert k == "InjectableTestContext" + assert InjectableTestContext in container + + # provider does not create default value in Container + with pytest.raises(ContextDefaultCannotBeCreated): + provider.get_value("n/a", NoDefaultInjectableContext) + assert NoDefaultInjectableContext not in container + + # explicitly create value + original = NoDefaultInjectableContext() + container.contexts[NoDefaultInjectableContext] = original + v, _ = provider.get_value("n/a", NoDefaultInjectableContext) + assert v is original + + # must assert if namespaces are provided + with pytest.raises(AssertionError): + provider.get_value("n/a", InjectableTestContext, ("ns1",)) + + # type hints that are not classes + literal = Literal["a"] + v, k = provider.get_value("n/a", literal) + assert v is None + assert k == "typing.Literal['a']" + + +def test_container_provider_embedded_inject(container: Container, environment: Any) -> None: + environment["INJECTED"] = "unparsable" + with container.injectable_context(InjectableTestContext(current_value="Embed")) as injected: + # must have top precedence - over the environ provider. environ provider is returning a value that will cannot be parsed + # but the container provider has a precedence and the lookup in environ provider will never happen + C = resolve_configuration(EmbeddedWithInjectableContext()) + assert C.injected.current_value == "Embed" + assert C.injected is injected + # remove first provider + container[ConfigProvidersContext].providers.pop(0) + # now environment will provide unparsable value + with pytest.raises(InvalidInitialValue): + C = resolve_configuration(EmbeddedWithInjectableContext()) diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py new file mode 100644 index 0000000000..87cb1de30e --- /dev/null +++ b/tests/common/configuration/test_environ_provider.py @@ -0,0 +1,106 @@ +import pytest +from typing import Any + +from dlt.common.typing import TSecretValue +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration.specs import RunConfiguration +from dlt.common.configuration.providers import environ as environ_provider + +from tests.utils import preserve_environ +from tests.common.configuration.utils import WrongConfiguration, SecretConfiguration, environment + + +@configspec +class SimpleConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + test_bool: bool = False + + +@configspec +class SecretKubeConfiguration(RunConfiguration): + pipeline_name: str = "secret kube" + secret_kube: TSecretValue = None + + +@configspec +class MockProdConfigurationVar(RunConfiguration): + pipeline_name: str = "comp" + + + +def test_resolves_from_environ(environment: Any) -> None: + environment["NONECONFIGVAR"] = "Some" + + C = WrongConfiguration() + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) + assert not C.is_partial() + + assert C.NoneConfigVar == 
environment["NONECONFIGVAR"] + + +def test_resolves_from_environ_with_coercion(environment: Any) -> None: + environment["TEST_BOOL"] = 'yes' + + C = SimpleConfiguration() + resolve._resolve_config_fields(C, explicit_namespaces=(), embedded_namespaces=(), accept_partial=False) + assert not C.is_partial() + + # value will be coerced to bool + assert C.test_bool is True + + +def test_secret(environment: Any) -> None: + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(SecretConfiguration()) + environment['SECRET_VALUE'] = "1" + C = resolve.resolve_configuration(SecretConfiguration()) + assert C.secret_value == "1" + # mock the path to point to secret storage + # from dlt.common.configuration import config_utils + path = environ_provider.SECRET_STORAGE_PATH + del environment['SECRET_VALUE'] + try: + # must read a secret file + environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" + C = resolve.resolve_configuration(SecretConfiguration()) + assert C.secret_value == "BANANA" + + # set some weird path, no secret file at all + del environment['SECRET_VALUE'] + environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(SecretConfiguration()) + + # set env which is a fallback for secret not as file + environment['SECRET_VALUE'] = "1" + C = resolve.resolve_configuration(SecretConfiguration()) + assert C.secret_value == "1" + finally: + environ_provider.SECRET_STORAGE_PATH = path + + +def test_secret_kube_fallback(environment: Any) -> None: + path = environ_provider.SECRET_STORAGE_PATH + try: + environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" + C = resolve.resolve_configuration(SecretKubeConfiguration()) + # all unix editors will add x10 at the end of file, it will be preserved + assert C.secret_kube == "kube\n" + # we propagate secrets back to environ and strip the whitespace + assert environment['SECRET_KUBE'] == "kube" + finally: + environ_provider.SECRET_STORAGE_PATH = path + + +def test_configuration_files(environment: Any) -> None: + # overwrite config file paths + environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" + C = resolve.resolve_configuration(MockProdConfigurationVar()) + assert C.config_files_storage_path == environment["CONFIG_FILES_STORAGE_PATH"] + assert C.has_configuration_file("hasn't") is False + assert C.has_configuration_file("event.schema.json") is True + assert C.get_configuration_file_path("event.schema.json") == "./tests/common/cases/schemas/ev1/event.schema.json" + with C.open_configuration_file("event.schema.json", "r") as f: + f.read() + with pytest.raises(ConfigFileNotFoundException): + C.open_configuration_file("hasn't", "r") diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py new file mode 100644 index 0000000000..fa4f72a4a1 --- /dev/null +++ b/tests/common/configuration/test_inject.py @@ -0,0 +1,212 @@ +import inspect +from typing import Any, Optional + +from dlt.common import Decimal +from dlt.common.typing import TSecretValue +from dlt.common.configuration.inject import _spec_from_signature, _get_spec_name_from_f, get_fun_spec, with_config +from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration + +from tests.utils import preserve_environ +from tests.common.configuration.utils import environment + +_DECIMAL_DEFAULT = Decimal("0.01") +_SECRET_DEFAULT = TSecretValue("PASS") +_CONFIG_DEFAULT = RunConfiguration() + + +def 
test_synthesize_spec_from_sig() -> None: + + # spec from typed signature without defaults + + def f_typed(p1: str, p2: Decimal, p3: Any, p4: Optional[RunConfiguration], p5: TSecretValue) -> None: + pass + + SPEC = _spec_from_signature(f_typed.__name__, inspect.getmodule(f_typed), inspect.signature(f_typed)) + assert SPEC.p1 is None + assert SPEC.p2 is None + assert SPEC.p3 is None + assert SPEC.p4 is None + assert SPEC.p5 is None + fields = SPEC().get_resolvable_fields() + assert fields == {"p1": str, "p2": Decimal, "p3": Any, "p4": Optional[RunConfiguration], "p5": TSecretValue} + + # spec from typed signatures with defaults + + def f_typed_default(t_p1: str = "str", t_p2: Decimal = _DECIMAL_DEFAULT, t_p3: Any = _SECRET_DEFAULT, t_p4: RunConfiguration = _CONFIG_DEFAULT, t_p5: str = None) -> None: + pass + + SPEC = _spec_from_signature(f_typed_default.__name__, inspect.getmodule(f_typed_default), inspect.signature(f_typed_default)) + assert SPEC.t_p1 == "str" + assert SPEC.t_p2 == _DECIMAL_DEFAULT + assert SPEC.t_p3 == _SECRET_DEFAULT + assert isinstance(SPEC.t_p4, RunConfiguration) + assert SPEC.t_p5 is None + fields = SPEC().get_resolvable_fields() + # Any will not assume TSecretValue type because at runtime it's a str + # setting default as None will convert type into optional (t_p5) + assert fields == {"t_p1": str, "t_p2": Decimal, "t_p3": str, "t_p4": RunConfiguration, "t_p5": Optional[str]} + + # spec from untyped signature + + def f_untyped(untyped_p1, untyped_p2) -> None: + pass + + SPEC = _spec_from_signature(f_untyped.__name__, inspect.getmodule(f_untyped), inspect.signature(f_untyped)) + assert SPEC.untyped_p1 is None + assert SPEC.untyped_p2 is None + fields = SPEC().get_resolvable_fields() + assert fields == {"untyped_p1": Any, "untyped_p2": Any,} + + # spec types derived from defaults + + + def f_untyped_default(untyped_p1 = "str", untyped_p2 = _DECIMAL_DEFAULT, untyped_p3 = _CONFIG_DEFAULT, untyped_p4 = None) -> None: + pass + + + SPEC = _spec_from_signature(f_untyped_default.__name__, inspect.getmodule(f_untyped_default), inspect.signature(f_untyped_default)) + assert SPEC.untyped_p1 == "str" + assert SPEC.untyped_p2 == _DECIMAL_DEFAULT + assert isinstance(SPEC.untyped_p3, RunConfiguration) + assert SPEC.untyped_p4 is None + fields = SPEC().get_resolvable_fields() + # untyped_p4 converted to Optional[Any] + assert fields == {"untyped_p1": str, "untyped_p2": Decimal, "untyped_p3": RunConfiguration, "untyped_p4": Optional[Any]} + + # spec from signatures containing positional only and keywords only args + + def f_pos_kw_only(pos_only_1, pos_only_2: str = "default", /, *, kw_only_1, kw_only_2: int = 2) -> None: + pass + + SPEC = _spec_from_signature(f_pos_kw_only.__name__, inspect.getmodule(f_pos_kw_only), inspect.signature(f_pos_kw_only)) + assert SPEC.pos_only_1 is None + assert SPEC.pos_only_2 == "default" + assert SPEC.kw_only_1 is None + assert SPEC.kw_only_2 == 2 + fields = SPEC().get_resolvable_fields() + assert fields == {"pos_only_1": Any, "pos_only_2": str, "kw_only_1": Any, "kw_only_2": int} + + # kw_only = True will filter in keywords only parameters + SPEC = _spec_from_signature(f_pos_kw_only.__name__, inspect.getmodule(f_pos_kw_only), inspect.signature(f_pos_kw_only), kw_only=True) + assert SPEC.kw_only_1 is None + assert SPEC.kw_only_2 == 2 + assert not hasattr(SPEC, "pos_only_1") + fields = SPEC().get_resolvable_fields() + assert fields == {"kw_only_1": Any, "kw_only_2": int} + + def f_variadic(var_1: str, *args, kw_var_1: str, **kwargs) -> None: + pass + + 
SPEC = _spec_from_signature(f_variadic.__name__, inspect.getmodule(f_variadic), inspect.signature(f_variadic)) + assert SPEC.var_1 is None + assert SPEC.kw_var_1 is None + assert not hasattr(SPEC, "args") + fields = SPEC().get_resolvable_fields() + assert fields == {"var_1": str, "kw_var_1": str} + + +def test_inject_with_non_injectable_param() -> None: + # one of parameters in signature has not valid hint and is skipped (ie. from_pipe) + pass + + +def test_inject_without_spec() -> None: + pass + + +def test_inject_without_spec_kw_only() -> None: + pass + + +def test_inject_with_auto_namespace(environment: Any) -> None: + environment["PIPE__VALUE"] = "test" + + @with_config(auto_namespace=True) + def f(pipeline_name, value): + assert value == "test" + + f("pipe") + + # make sure the spec is available for decorated fun + assert get_fun_spec(f) is not None + assert hasattr(get_fun_spec(f), "pipeline_name") + + +def test_inject_with_spec() -> None: + pass + + +def test_inject_with_str_namespaces() -> None: + # namespaces param is str not tuple + pass + + +def test_inject_with_func_namespace() -> None: + # function to get namespaces from the arguments is provided + pass + + +def test_inject_on_class_and_methods() -> None: + pass + + +def test_set_defaults_for_positional_args() -> None: + # set defaults for positional args that are part of derived SPEC + # set defaults for positional args that are part of provided SPEC + pass + + +def test_inject_spec_remainder_in_kwargs() -> None: + # if the wrapped func contains kwargs then all the fields from spec without matching func args must be injected in kwargs + pass + + +def test_inject_spec_in_kwargs() -> None: + # the resolved spec is injected in kwargs + pass + + +def test_resolved_spec_in_kwargs_pass_through() -> None: + # if last_config is in kwargs then use it and do not resolve it anew + pass + + +def test_inject_spec_into_argument_with_spec_type() -> None: + # if signature contains argument with type of SPEC, it gets injected there + pass + + +def test_initial_spec_from_arg_with_spec_type() -> None: + # if signature contains argument with type of SPEC, get its value to init SPEC (instead of calling the constructor()) + pass + + +def test_auto_derived_spec_type_name() -> None: + + + class AutoNameTest: + @with_config + def __init__(self, pos_par, /, kw_par) -> None: + pass + + @classmethod + @with_config + def make_class(cls, pos_par, /, kw_par) -> None: + pass + + @staticmethod + @with_config + def make_stuff(pos_par, /, kw_par) -> None: + pass + + @with_config + def stuff_test(pos_par, /, kw_par) -> None: + pass + + # name is composed via __qualname__ of func + assert _get_spec_name_from_f(AutoNameTest.__init__) == "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" + # synthesized spec present in current module + assert "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" in globals() + # instantiate + C: BaseConfiguration = globals()["TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration"]() + assert C.get_resolvable_fields() == {"pos_par": Any, "kw_par": Any} \ No newline at end of file diff --git a/tests/common/configuration/test_namespaces.py b/tests/common/configuration/test_namespaces.py new file mode 100644 index 0000000000..85369c8fdb --- /dev/null +++ b/tests/common/configuration/test_namespaces.py @@ -0,0 +1,266 @@ +import pytest +from typing import Any, Optional +from dlt.common.configuration.container import Container + +from dlt.common.configuration import configspec, ConfigFieldMissingException, resolve, 
inject_namespace +from dlt.common.configuration.specs import BaseConfiguration, ConfigNamespacesContext +from dlt.common.configuration.exceptions import LookupTrace + +from tests.utils import preserve_environ +from tests.common.configuration.utils import MockProvider, WrongConfiguration, SecretConfiguration, NamespacedConfiguration, environment, mock_provider + + +@configspec +class SingleValConfiguration(BaseConfiguration): + sv: str + + +@configspec +class EmbeddedConfiguration(BaseConfiguration): + sv_config: Optional[SingleValConfiguration] + + +@configspec +class EmbeddedWithNamespacedConfiguration(BaseConfiguration): + embedded: NamespacedConfiguration + + +@configspec +class EmbeddedIgnoredConfiguration(BaseConfiguration): + # underscore prevents the field name to be added to embedded namespaces + _sv_config: Optional[SingleValConfiguration] + + +@configspec +class EmbeddedIgnoredWithNamespacedConfiguration(BaseConfiguration): + _embedded: NamespacedConfiguration + + +@configspec +class EmbeddedWithIgnoredEmbeddedConfiguration(BaseConfiguration): + ignored_embedded: EmbeddedIgnoredWithNamespacedConfiguration + + +def test_namespaced_configuration(environment: Any) -> None: + with pytest.raises(ConfigFieldMissingException) as exc_val: + resolve.resolve_configuration(NamespacedConfiguration()) + + assert list(exc_val.value.traces.keys()) == ["password"] + assert exc_val.value.spec_name == "NamespacedConfiguration" + # check trace + traces = exc_val.value.traces["password"] + # only one provider and namespace was tried + assert len(traces) == 3 + assert traces[0] == LookupTrace("Environment Variables", ["DLT_TEST"], "DLT_TEST__PASSWORD", None) + assert traces[1] == LookupTrace("Pipeline secrets.toml", ["DLT_TEST"], "DLT_TEST.password", None) + assert traces[2] == LookupTrace("Pipeline config.toml", ["DLT_TEST"], "DLT_TEST.password", None) + + # init vars work without namespace + C = resolve.resolve_configuration(NamespacedConfiguration(), initial_value={"password": "PASS"}) + assert C.password == "PASS" + + # env var must be prefixed + environment["PASSWORD"] = "PASS" + with pytest.raises(ConfigFieldMissingException) as exc_val: + resolve.resolve_configuration(NamespacedConfiguration()) + environment["DLT_TEST__PASSWORD"] = "PASS" + C = resolve.resolve_configuration(NamespacedConfiguration()) + assert C.password == "PASS" + + +def test_explicit_namespaces(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + # mock providers separates namespaces with | and key with - + _, k = mock_provider.get_value("key", Any) + assert k == "-key" + _, k = mock_provider.get_value("key", Any, "ns1") + assert k == "ns1-key" + _, k = mock_provider.get_value("key", Any, "ns1", "ns2") + assert k == "ns1|ns2-key" + + # via make configuration + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespace == () + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) + # value is returned only on empty namespace + assert mock_provider.last_namespace == () + # always start with more precise namespace + assert mock_provider.last_namespaces == [("ns1",), ()] + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1", "ns2")) + assert mock_provider.last_namespaces == [("ns1", "ns2"), ("ns1",), ()] + + +def test_explicit_namespaces_with_namespaced_config(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + # with 
namespaced config + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(NamespacedConfiguration()) + assert mock_provider.last_namespace == ("DLT_TEST",) + # first the native representation of NamespacedConfiguration is queried with (), and then the fields in NamespacedConfiguration are queried only in DLT_TEST + assert mock_provider.last_namespaces == [(), ("DLT_TEST",)] + # namespaced config is always innermost + mock_provider.reset_stats() + resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + mock_provider.reset_stats() + resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("ns1", "ns2")) + assert mock_provider.last_namespaces == [("ns1", "ns2"), ("ns1",), (), ("ns1", "ns2", "DLT_TEST"), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + + +def test_overwrite_config_namespace_from_embedded(mock_provider: MockProvider) -> None: + mock_provider.value = {} + mock_provider.return_value_on = ("embedded",) + resolve.resolve_configuration(EmbeddedWithNamespacedConfiguration()) + # when resolving the config namespace DLT_TEST was removed and the embedded namespace was used instead + assert mock_provider.last_namespace == ("embedded",) + # lookup in order: () - parent config when looking for "embedded", then from "embedded" config + assert mock_provider.last_namespaces == [(), ("embedded",)] + + +def test_explicit_namespaces_from_embedded_config(mock_provider: MockProvider) -> None: + mock_provider.value = {"sv": "A"} + mock_provider.return_value_on = ("sv_config",) + c = resolve.resolve_configuration(EmbeddedConfiguration()) + # we mock the dictionary below as the value for all requests + assert c.sv_config.sv == '{"sv": "A"}' + # following namespaces were used when resolving EmbeddedConfig: () trying to get initial value for the whole embedded sv_config, then ("sv_config",), () to resolve sv in sv_config + assert mock_provider.last_namespaces == [(), ("sv_config",)] + # embedded namespace inner of explicit + mock_provider.reset_stats() + resolve.resolve_configuration(EmbeddedConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), (), ("ns1", "sv_config",), ("sv_config",)] + + +def test_ignore_embedded_namespace_by_field_name(mock_provider: MockProvider) -> None: + mock_provider.value = {"sv": "A"} + resolve.resolve_configuration(EmbeddedIgnoredConfiguration()) + # _sv_config will not be added to embedded namespaces and looked up + assert mock_provider.last_namespaces == [()] + mock_provider.reset_stats() + resolve.resolve_configuration(EmbeddedIgnoredConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), ()] + # if namespace config exist, it won't be replaced by embedded namespace + mock_provider.reset_stats() + mock_provider.value = {} + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(EmbeddedIgnoredWithNamespacedConfiguration()) + assert mock_provider.last_namespaces == [(), ("DLT_TEST",)] + # embedded configuration of depth 2: first normal, second - ignored + mock_provider.reset_stats() + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(EmbeddedWithIgnoredEmbeddedConfiguration()) + assert mock_provider.last_namespaces == [(), ('ignored_embedded',), ('ignored_embedded', 'DLT_TEST'), ('DLT_TEST',)] + + +def test_injected_namespaces(mock_provider: MockProvider) -> None: + container = Container() + mock_provider.value = 
"value" + + with container.injectable_context(ConfigNamespacesContext(namespaces=("inj-ns1",))): + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1",), ()] + mock_provider.reset_stats() + # explicit namespace preempts injected namespace + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("ns1",), ()] + # namespaced config inner of injected + mock_provider.reset_stats() + mock_provider.return_value_on = ("DLT_TEST",) + resolve.resolve_configuration(NamespacedConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1",), (), ("inj-ns1", "DLT_TEST"), ("DLT_TEST",)] + # injected namespace inner of ns coming from embedded config + mock_provider.reset_stats() + mock_provider.return_value_on = () + mock_provider.value = {"sv": "A"} + resolve.resolve_configuration(EmbeddedConfiguration()) + # first we look for sv_config -> ("inj-ns1",), () then we look for sv + assert mock_provider.last_namespaces == [("inj-ns1",), (), ("inj-ns1", "sv_config"), ("sv_config",)] + + # multiple injected namespaces + with container.injectable_context(ConfigNamespacesContext(namespaces=("inj-ns1", "inj-ns2"))): + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] + mock_provider.reset_stats() + + +def test_namespace_with_pipeline_name(mock_provider: MockProvider) -> None: + # if pipeline name is present, keys will be looked up twice: with pipeline as top level namespace and without it + + container = Container() + mock_provider.value = "value" + + with container.injectable_context(ConfigNamespacesContext(pipeline_name="PIPE")): + mock_provider.return_value_on = () + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE",), ()] + + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) + # PIPE namespace is exhausted then another lookup without PIPE + assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",), ("ns1",), ()] + + mock_provider.return_value_on = ("PIPE", ) + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration(), namespaces=("ns1",)) + assert mock_provider.last_namespaces == [("PIPE", "ns1"), ("PIPE",)] + + # with both pipe and config namespaces are always present in lookup + # "PIPE", "DLT_TEST" + mock_provider.return_value_on = () + mock_provider.reset_stats() + # () will never be searched + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NamespacedConfiguration()) + mock_provider.return_value_on = ("DLT_TEST",) + mock_provider.reset_stats() + resolve.resolve_configuration(NamespacedConfiguration()) + # first the whole NamespacedConfiguration is looked under key DLT_TEST (namespaces: ('PIPE',), ()), then fields of NamespacedConfiguration + assert mock_provider.last_namespaces == [('PIPE',), (), ("PIPE", "DLT_TEST"), ("DLT_TEST",)] + + # with pipeline and injected namespaces + with container.injectable_context(ConfigNamespacesContext(pipeline_name="PIPE", namespaces=("inj-ns1",))): + mock_provider.return_value_on = () + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + + +# def test_namespaces_with_duplicate(mock_provider: MockProvider) -> None: 
+# container = Container() +# mock_provider.value = "value" + +# with container.injectable_context(ConfigNamespacesContext(pipeline_name="DLT_TEST", namespaces=("DLT_TEST", "DLT_TEST"))): +# mock_provider.return_value_on = ("DLT_TEST",) +# resolve.resolve_configuration(NamespacedConfiguration(), namespaces=("DLT_TEST", "DLT_TEST")) +# # no duplicates are removed, duplicates are misconfiguration +# # note: use dict.fromkeys to create ordered sets from lists if we ever want to remove duplicates +# # the lookup tuples are create as follows: +# # 1. (pipeline name, deduplicated namespaces, config namespace) +# # 2. (deduplicated namespaces, config namespace) +# # 3. (pipeline name, config namespace) +# # 4. (config namespace) +# assert mock_provider.last_namespaces == [("DLT_TEST", "DLT_TEST", "DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST"), ("DLT_TEST", "DLT_TEST"), ("DLT_TEST",)] + + +def test_inject_namespace(mock_provider: MockProvider) -> None: + mock_provider.value = "value" + + with inject_namespace(ConfigNamespacesContext(pipeline_name="PIPE", namespaces=("inj-ns1",))): + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + + # inject with merge previous + with inject_namespace(ConfigNamespacesContext(namespaces=("inj-ns2",))): + mock_provider.reset_stats() + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [("PIPE", "inj-ns2"), ("PIPE",), ("inj-ns2",), ()] + + # inject without merge + mock_provider.reset_stats() + with inject_namespace(ConfigNamespacesContext(), merge_existing=False): + resolve.resolve_configuration(SingleValConfiguration()) + assert mock_provider.last_namespaces == [()] diff --git a/tests/common/configuration/test_providers.py b/tests/common/configuration/test_providers.py new file mode 100644 index 0000000000..2e88a7af58 --- /dev/null +++ b/tests/common/configuration/test_providers.py @@ -0,0 +1,17 @@ +def test_providers_order() -> None: + pass + + +def test_add_remove_providers() -> None: + # TODO: we should be able to add and remove providers + pass + + +def test_providers_autodetect_and_config() -> None: + # TODO: toml based and remote vaults should be configured and/or autodetected + pass + + +def test_providers_value_getter() -> None: + # TODO: it should be possible to get a value from providers' chain via `config` and `secrets` objects via indexer (nested) or explicit key, *namespaces getter + pass \ No newline at end of file diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py new file mode 100644 index 0000000000..304c3b59d1 --- /dev/null +++ b/tests/common/configuration/test_toml_provider.py @@ -0,0 +1,169 @@ +import pytest +from typing import Any, Iterator +import datetime # noqa: I251 + + +from dlt.common import pendulum +from dlt.common.configuration import configspec, ConfigFieldMissingException, ConfigFileNotFoundException, resolve +from dlt.common.configuration.container import Container +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.exceptions import LookupTrace +from dlt.common.configuration.providers.environ import EnvironProvider +from dlt.common.configuration.providers.toml import SecretsTomlProvider, ConfigTomlProvider +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs import 
BaseConfiguration, GcpClientCredentials, PostgresCredentials +from dlt.common.typing import TSecretValue + +from tests.utils import preserve_environ +from tests.common.configuration.utils import WithCredentialsConfiguration, CoercionTestConfiguration, COERCIONS, SecretConfiguration, environment + + +@configspec +class EmbeddedWithGcpStorage(BaseConfiguration): + gcp_storage: GcpClientCredentials + + +@configspec +class EmbeddedWithGcpCredentials(BaseConfiguration): + credentials: GcpClientCredentials + + +@pytest.fixture +def providers() -> Iterator[ConfigProvidersContext]: + pipeline_root = "./tests/common/cases/configuration/.dlt" + ctx = ConfigProvidersContext() + ctx.providers.clear() + ctx.add_provider(SecretsTomlProvider(project_dir=pipeline_root)) + ctx.add_provider(ConfigTomlProvider(project_dir=pipeline_root)) + with Container().injectable_context(ctx): + yield ctx + + +def test_secrets_from_toml_secrets() -> None: + with pytest.raises(ConfigFieldMissingException) as py_ex: + resolve.resolve_configuration(SecretConfiguration()) + + # only two traces because TSecretValue won't be checked in config.toml provider + traces = py_ex.value.traces["secret_value"] + assert len(traces) == 2 + assert traces[0] == LookupTrace("Environment Variables", [], "SECRET_VALUE", None) + assert traces[1] == LookupTrace("Pipeline secrets.toml", [], "secret_value", None) + + with pytest.raises(ConfigFieldMissingException) as py_ex: + resolve.resolve_configuration(WithCredentialsConfiguration()) + + +def test_toml_types(providers: ConfigProvidersContext) -> None: + # resolve CoercionTestConfiguration from typecheck namespace + c = resolve.resolve_configuration(CoercionTestConfiguration(), namespaces=("typecheck",)) + for k, v in COERCIONS.items(): + # toml does not know tuples + if isinstance(v, tuple): + v = list(v) + if isinstance(v, datetime.datetime): + v = pendulum.parse("1979-05-27T07:32:00-08:00") + assert v == c[k] + + +def test_config_provider_order(providers: ConfigProvidersContext, environment: Any) -> None: + + # add env provider + providers.providers.insert(0, EnvironProvider()) + + @with_config(namespaces=("api",)) + def single_val(port): + return port + + # secrets have api.port=1023 and this will be used + assert single_val() == 1023 + + # env will make it string, also namespace is optional + environment["PORT"] = "UNKNOWN" + assert single_val() == "UNKNOWN" + + environment["API__PORT"] = "1025" + assert single_val() == "1025" + + +def test_toml_mixed_config_inject(providers: ConfigProvidersContext) -> None: + # get data from both providers + + @with_config + def mixed_val(api_type, secret_value: TSecretValue, typecheck: Any): + return api_type, secret_value, typecheck + + _tup = mixed_val() + assert _tup[0] == "REST" + assert _tup[1] == "2137" + assert isinstance(_tup[2], dict) + + +def test_toml_namespaces(providers: ConfigProvidersContext) -> None: + cfg = providers["Pipeline config.toml"] + assert cfg.get_value("api_type", str) == ("REST", "api_type") + assert cfg.get_value("port", int, "api") == (1024, "api.port") + assert cfg.get_value("param1", str, "api", "params") == ("a", "api.params.param1") + + +def test_secrets_toml_credentials(providers: ConfigProvidersContext) -> None: + # there are credentials exactly under destination.bigquery.credentials + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination", "bigquery")) + assert c.project_id.endswith("destination.bigquery.credentials") + # there are no destination.gcp_storage.credentials so it will 
fallback to "destination"."credentials" + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination", "gcp_storage")) + assert c.project_id.endswith("destination.credentials") + # also explicit + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) + assert c.project_id.endswith("destination.credentials") + # there's "credentials" key but does not contain valid gcp credentials + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(GcpClientCredentials()) + # also try postgres credentials + c = resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "redshift")) + assert c.dbname == "destination.redshift.credentials" + # bigquery credentials do not match redshift credentials + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(PostgresCredentials(), namespaces=("destination", "bigquery")) + + + +def test_secrets_toml_embedded_credentials(providers: ConfigProvidersContext) -> None: + # will try destination.bigquery.credentials + c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), namespaces=("destination", "bigquery")) + assert c.credentials.project_id.endswith("destination.bigquery.credentials") + # will try destination.gcp_storage.credentials and fallback to destination.credentials + c = resolve.resolve_configuration(EmbeddedWithGcpCredentials(), namespaces=("destination", "gcp_storage")) + assert c.credentials.project_id.endswith("destination.credentials") + # will try everything until credentials in the root where incomplete credentials are present + c = EmbeddedWithGcpCredentials() + # create embedded config that will be passed as initial + c.credentials = GcpClientCredentials() + with pytest.raises(ConfigFieldMissingException) as py_ex: + resolve.resolve_configuration(c, namespaces=("middleware", "storage")) + # so we can read partially filled configuration here + assert c.credentials.project_id.endswith("-credentials") + assert set(py_ex.value.traces.keys()) == {"client_email", "private_key"} + + # embed "gcp_storage" will bubble up to the very top, never reverts to "credentials" + c = resolve.resolve_configuration(EmbeddedWithGcpStorage(), namespaces=("destination", "bigquery")) + assert c.gcp_storage.project_id.endswith("-gcp-storage") + + # also explicit + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("destination",)) + assert c.project_id.endswith("destination.credentials") + # there's "credentials" key but does not contain valid gcp credentials + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(GcpClientCredentials()) + + +# def test_secrets_toml_ignore_dict_initial(providers: ConfigProvidersContext) -> None: + + + +def test_secrets_toml_credentials_from_native_repr(providers: ConfigProvidersContext) -> None: + # cfg = providers["Pipeline secrets.toml"] + # print(cfg._toml) + # print(cfg._toml["source"]["credentials"]) + # resolve gcp_credentials by parsing initial value which is str holding json doc + c = resolve.resolve_configuration(GcpClientCredentials(), namespaces=("source",)) + assert c.project_id.endswith("source.credentials") diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py new file mode 100644 index 0000000000..e03a123c47 --- /dev/null +++ b/tests/common/configuration/utils.py @@ -0,0 +1,146 @@ +import pytest +from os import environ +import datetime # noqa: I251 +from typing import Any, List, Optional, Tuple, Type, Dict, 
MutableMapping, Optional, Sequence + +from dlt.common import Decimal, pendulum +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.typing import TSecretValue, StrAny +from dlt.common.configuration import configspec +from dlt.common.configuration.providers import Provider +from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration, RunConfiguration + + +@configspec +class WrongConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + NoneConfigVar: str = None + log_color: bool = True + + +@configspec +class CoercionTestConfiguration(RunConfiguration): + pipeline_name: str = "Some Name" + str_val: str = None + int_val: int = None + bool_val: bool = None + list_val: list = None # type: ignore + dict_val: dict = None # type: ignore + bytes_val: bytes = None + float_val: float = None + tuple_val: Tuple[int, int, StrAny] = None + any_val: Any = None + none_val: str = None + COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None + date_val: datetime.datetime = None + dec_val: Decimal = None + sequence_val: Sequence[str] = None + gen_list_val: List[str] = None + mapping_val: StrAny = None + mutable_mapping_val: MutableMapping[str, str] = None + + +@configspec +class SecretConfiguration(BaseConfiguration): + secret_value: TSecretValue = None + + +@configspec +class SecretCredentials(CredentialsConfiguration): + secret_value: TSecretValue = None + + +@configspec +class WithCredentialsConfiguration(BaseConfiguration): + credentials: SecretCredentials + + +@configspec +class NamespacedConfiguration(BaseConfiguration): + __namespace__ = "DLT_TEST" + + password: str = None + + +@pytest.fixture(scope="function") +def environment() -> Any: + environ.clear() + return environ + + +@pytest.fixture(scope="function") +def mock_provider() -> "MockProvider": + container = Container() + with container.injectable_context(ConfigProvidersContext()) as providers: + # replace all providers with MockProvider that does not support secrets + mock_provider = MockProvider() + providers.providers = [mock_provider] + yield mock_provider + + +class MockProvider(Provider): + + def __init__(self) -> None: + self.value: Any = None + self.return_value_on: Tuple[str] = () + self.reset_stats() + + def reset_stats(self) -> None: + self.last_namespace: Tuple[str] = None + self.last_namespaces: List[Tuple[str]] = [] + + def get_value(self, key: str, hint: Type[Any], *namespaces: str) -> Tuple[Optional[Any], str]: + self.last_namespace = namespaces + self.last_namespaces.append(namespaces) + # print("|".join(namespaces) + "-" + key) + if namespaces == self.return_value_on: + rv = self.value + else: + rv = None + return rv, "|".join(namespaces) + "-" + key + + @property + def supports_secrets(self) -> bool: + return False + + @property + def supports_namespaces(self) -> bool: + return True + + @property + def name(self) -> str: + return "Mock Provider" + + +class SecretMockProvider(MockProvider): + @property + def supports_secrets(self) -> bool: + return True + + +COERCIONS = { + 'str_val': 'test string', + 'int_val': 12345, + 'bool_val': True, + 'list_val': [1, "2", [3]], + 'dict_val': { + 'a': 1, + "b": "2" + }, + 'bytes_val': b'Hello World!', + 'float_val': 1.18927, + "tuple_val": (1, 2, {"1": "complicated dicts allowed in literal eval"}), + 'any_val': "function() {}", + 'none_val': "none", + 'COMPLEX_VAL': { + "_": [1440, ["*"], []], + "change-email": [560, ["*"], []] + 
}, + "date_val": pendulum.now(), + "dec_val": Decimal("22.38"), + "sequence_val": ["A", "B", "KAPPA"], + "gen_list_val": ["C", "Z", "N"], + "mapping_val": {"FL": 1, "FR": {"1": 2}}, + "mutable_mapping_val": {"str": "str"} +} \ No newline at end of file diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index b432768d7d..a9dbe5a2ba 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -4,7 +4,6 @@ from dlt.common.utils import digest128, uniq_id from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.sources import DLT_METADATA_FIELD, with_table_name from dlt.common.normalizers.json.relational import JSONNormalizerConfigPropagation, _flatten, _get_child_row_hash, _normalize_row, normalize_data_item @@ -508,6 +507,19 @@ def test_preserves_complex_types_list(schema: Schema) -> None: assert root_row[1]["value"] == row["value"] +def test_wrap_in_dict(schema: Schema) -> None: + # json normalizer wraps in dict + row = list(schema.normalize_data_item(schema, 1, "load_id", "simplex"))[0][1] + assert row["value"] == 1 + assert row["_dlt_load_id"] == "load_id" + # wrap a list + rows = list(schema.normalize_data_item(schema, [1, 2, 3, 4, "A"], "load_id", "listex")) + assert len(rows) == 6 + assert rows[0][0] == ("listex", None,) + assert rows[1][0] == ("listex__value", "listex") + assert rows[-1][1]["value"] == "A" + + def test_complex_types_for_recursion_level(schema: Schema) -> None: add_dlt_root_id_propagation(schema) # if max recursion depth is set, nested elements will be kept as complex @@ -519,13 +531,13 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: "lo": [{"e": {"v": 1}}] # , {"e": {"v": 2}}, {"e":{"v":3 }} }] } - n_rows_nl = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows_nl = list(schema.normalize_data_item(schema, row, "load_id", "default")) # all nested elements were yielded assert ["default", "default__f", "default__f__l", "default__f__lo"] == [r[0][0] for r in n_rows_nl] # set max nesting to 0 schema._normalizers_config["json"]["config"]["max_nesting"] = 0 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) # the "f" element is left as complex type and not normalized assert len(n_rows) == 1 assert n_rows[0][0][0] == "default" @@ -534,7 +546,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 1 schema._normalizers_config["json"]["config"]["max_nesting"] = 1 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert len(n_rows) == 2 assert ["default", "default__f"] == [r[0][0] for r in n_rows] # on level f, "l" and "lo" are not normalized @@ -545,7 +557,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 2 schema._normalizers_config["json"]["config"]["max_nesting"] = 2 - n_rows = list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert len(n_rows) == 4 # in default__f__lo the dicts that would be flattened are complex types last_row = n_rows[3] @@ -553,7 +565,7 @@ def test_complex_types_for_recursion_level(schema: Schema) -> None: # max nesting 3 schema._normalizers_config["json"]["config"]["max_nesting"] = 3 - n_rows = 
list(schema.normalize_data_item(schema, row, "load_id")) + n_rows = list(schema.normalize_data_item(schema, row, "load_id", "default")) assert n_rows_nl == n_rows @@ -570,13 +582,11 @@ def test_extract_with_table_name_meta() -> None: } # force table name rows = list( - normalize_data_item(create_schema_with_name("discord"), with_table_name(row, "channel"), "load_id") + normalize_data_item(create_schema_with_name("discord"), row, "load_id", "channel") ) # table is channel assert rows[0][0][0] == "channel" normalized_row = rows[0][1] - # _dlt_meta must be removed must be removed - assert DLT_METADATA_FIELD not in normalized_row assert normalized_row["guild_id"] == "815421435900198962" assert "_dlt_id" in normalized_row assert normalized_row["_dlt_load_id"] == "load_id" @@ -588,7 +598,7 @@ def test_table_name_meta_normalized() -> None: } # force table name rows = list( - normalize_data_item(create_schema_with_name("discord"), with_table_name(row, "channelSURFING"), "load_id") + normalize_data_item(create_schema_with_name("discord"), row, "load_id", "channelSURFING") ) # table is channel assert rows[0][0][0] == "channel_surfing" @@ -607,7 +617,7 @@ def test_parse_with_primary_key() -> None: "wo_id": [1, 2, 3] }] } - rows = list(normalize_data_item(schema, row, "load_id")) + rows = list(normalize_data_item(schema, row, "load_id", "discord")) # get root root = next(t[1] for t in rows if t[0][0] == "discord") assert root["_dlt_id"] == digest128("817949077341208606") @@ -633,7 +643,7 @@ def test_parse_with_primary_key() -> None: def test_keeps_none_values() -> None: row = {"a": None, "timestamp": 7} - rows = list(normalize_data_item(create_schema_with_name("other"), row, "1762162.1212")) + rows = list(normalize_data_item(create_schema_with_name("other"), row, "1762162.1212", "other")) table_name = rows[0][0][0] assert table_name == "other" normalized_row = rows[0][1] diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index dd883e0254..c5eb276dda 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -1,7 +1,9 @@ import gc +import pytest from multiprocessing.pool import Pool from multiprocessing.dummy import Pool as ThreadPool -import pytest + +from dlt.normalize.configuration import SchemaVolumeConfiguration from tests.common.runners.utils import _TestRunnable from tests.utils import skipifspawn @@ -62,3 +64,24 @@ def test_weak_pool_ref() -> None: # weak reference will be removed from container with pytest.raises(KeyError): r = wref[rid] + + +def test_configuredworker() -> None: + # call worker method with CONFIG values that should be restored into CONFIG type + config = SchemaVolumeConfiguration() + config["import_schema_path"] = "test_schema_path" + _worker_1(config, "PX1", par2="PX2") + + # must also work across process boundary + with Pool(1) as p: + p.starmap(_worker_1, [(config, "PX1", "PX2")]) + + +def _worker_1(CONFIG: SchemaVolumeConfiguration, par1: str, par2: str = "DEFAULT") -> None: + # a correct type was passed + assert type(CONFIG) is SchemaVolumeConfiguration + # check if config values are restored + assert CONFIG.import_schema_path == "test_schema_path" + # check if other parameters are correctly + assert par1 == "PX1" + assert par2 == "PX2" \ No newline at end of file diff --git a/tests/common/runners/test_runners.py b/tests/common/runners/test_runners.py index fc1e5df55b..35ac059371 100644 --- a/tests/common/runners/test_runners.py +++ b/tests/common/runners/test_runners.py @@ -5,48 +5,53 @@ from 
dlt.cli import TRunnerArgs from dlt.common import signals -from dlt.common.typing import StrAny -from dlt.common.configuration import PoolRunnerConfiguration, make_configuration -from dlt.common.configuration.pool_runner_configuration import TPoolType +from dlt.common.configuration import resolve_configuration, configspec +from dlt.common.configuration.specs import PoolRunnerConfiguration, TPoolType from dlt.common.exceptions import DltException, SignalReceivedException, TimeRangeExhaustedException, UnsupportedProcessStartMethodException from dlt.common.runners import pool_runner as runner from tests.common.runners.utils import _TestRunnable from tests.utils import init_logger + +@configspec class ModPoolRunnerConfiguration(PoolRunnerConfiguration): - IS_SINGLE_RUN: bool = True - WAIT_RUNS: int = 1 - PIPELINE_NAME: str = "testrunners" - POOL_TYPE: TPoolType = "none" - RUN_SLEEP: float = 0.1 - RUN_SLEEP_IDLE: float = 0.1 - RUN_SLEEP_WHEN_FAILED: float = 0.1 + is_single_run: bool = True + wait_runs: int = 1 + pipeline_name: str = "testrunners" + pool_type: TPoolType = "none" + run_sleep: float = 0.1 + run_sleep_idle: float = 0.1 + run_sleep_when_failed: float = 0.1 +@configspec class StopExceptionRunnerConfiguration(ModPoolRunnerConfiguration): - EXIT_ON_EXCEPTION: bool = True + exit_on_exception: bool = True +@configspec class LimitedPoolRunnerConfiguration(ModPoolRunnerConfiguration): - STOP_AFTER_RUNS: int = 5 + stop_after_runs: int = 5 +@configspec class ProcessPoolConfiguration(ModPoolRunnerConfiguration): - POOL_TYPE: TPoolType = "process" + pool_type: TPoolType = "process" +@configspec class ThreadPoolConfiguration(ModPoolRunnerConfiguration): - POOL_TYPE: TPoolType = "thread" + pool_type: TPoolType = "thread" -def configure(C: Type[PoolRunnerConfiguration], args: TRunnerArgs) -> Type[PoolRunnerConfiguration]: - return make_configuration(C, C, initial_values=args._asdict()) +def configure(C: Type[PoolRunnerConfiguration], args: TRunnerArgs) -> PoolRunnerConfiguration: + return resolve_configuration(C(), initial_value=args._asdict()) @pytest.fixture(scope="module", autouse=True) def logger_autouse() -> None: - init_logger(ModPoolRunnerConfiguration) + init_logger() @pytest.fixture(autouse=True) diff --git a/tests/common/schema/custom_normalizers.py b/tests/common/schema/custom_normalizers.py index 799d7a637e..a69b22df1f 100644 --- a/tests/common/schema/custom_normalizers.py +++ b/tests/common/schema/custom_normalizers.py @@ -12,11 +12,15 @@ def normalize_column_name(name: str) -> str: return "column_" + name.lower() +def normalize_schema_name(name: str) -> str: + return name.lower() + + def extend_schema(schema: Schema) -> None: json_config = schema._normalizers_config["json"]["config"] d_h = schema._settings.setdefault("default_hints", {}) d_h["not_null"] = json_config["not_null"] -def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator: - yield ("table", None), source_event +def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str, table_name) -> TNormalizedRowIterator: + yield (table_name, None), source_event diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 68170b2ee8..2c82a95429 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping, MutableSequence from typing import Any, Type import pytest import datetime # noqa: I251 @@ -155,13 +156,13 @@ def test_coerce_type_to_timestamp() 
-> None: # test wrong unix timestamps with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "double", -1000000000000000000000000000)) + utils.coerce_type("timestamp", "double", -1000000000000000000000000000) with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "double", 1000000000000000000000000000)) + utils.coerce_type("timestamp", "double", 1000000000000000000000000000) # formats with timezones are not parsed with pytest.raises(ValueError): - print(utils.coerce_type("timestamp", "text", "06/04/22, 11:15PM IST")) + utils.coerce_type("timestamp", "text", "06/04/22, 11:15PM IST") # we do not parse RFC 822, 2822, 850 etc. with pytest.raises(ValueError): @@ -198,13 +199,24 @@ def test_py_type_to_sc_type() -> None: assert utils.py_type_to_sc_type(int) == "bigint" assert utils.py_type_to_sc_type(float) == "double" assert utils.py_type_to_sc_type(str) == "text" - # unknown types are recognized as text - assert utils.py_type_to_sc_type(Exception) == "text" assert utils.py_type_to_sc_type(type(pendulum.now())) == "timestamp" assert utils.py_type_to_sc_type(type(datetime.datetime(1988, 12, 1))) == "timestamp" assert utils.py_type_to_sc_type(type(Decimal(1))) == "decimal" assert utils.py_type_to_sc_type(type(HexBytes("0xFF"))) == "binary" assert utils.py_type_to_sc_type(type(Wei.from_int256(2137, decimals=2))) == "wei" + # unknown types raise TypeException + with pytest.raises(TypeError): + utils.py_type_to_sc_type(Any) + # none type raises TypeException + with pytest.raises(TypeError): + utils.py_type_to_sc_type(type(None)) + # complex types + assert utils.py_type_to_sc_type(list) == "complex" + # assert utils.py_type_to_sc_type(set) == "complex" + assert utils.py_type_to_sc_type(dict) == "complex" + assert utils.py_type_to_sc_type(tuple) == "complex" + assert utils.py_type_to_sc_type(Mapping) == "complex" + assert utils.py_type_to_sc_type(MutableSequence) == "complex" def test_coerce_type_complex() -> None: diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index ddec6ac34e..03a2e6de4a 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -1,7 +1,6 @@ import pytest from copy import deepcopy from dlt.common.schema.exceptions import ParentTableNotFoundException -from dlt.common.sources import with_table_name from dlt.common.typing import StrAny from dlt.common.schema import Schema @@ -73,7 +72,7 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: updates = [] - for (t, p), row in schema.normalize_data_item(schema, with_table_name(source_row, "event_bot"), "load_id"): + for (t, p), row in schema.normalize_data_item(schema, source_row, "load_id", "event_bot"): row = schema.filter_row(t, row) if not row: # those rows are fully removed @@ -98,7 +97,7 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: _add_excludes(schema) schema.get_table("event_bot")["filters"]["includes"].extend(["re:^metadata___dlt_", "re:^metadata__elvl1___dlt_"]) schema._compile_regexes() - for (t, p), row in schema.normalize_data_item(schema, with_table_name(source_row, "event_bot"), "load_id"): + for (t, p), row in schema.normalize_data_item(schema, source_row, "load_id", "event_bot"): row = schema.filter_row(t, row) if p is None: assert "_dlt_id" in row diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index 37c94deded..73a9021def 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -125,7 
+125,7 @@ def test_coerce_row(schema: Schema) -> None: schema.update_schema(new_table) with pytest.raises(CannotCoerceColumnException) as exc_val: # now pass the binary that would create binary variant - but the column is occupied by text type - print(schema.coerce_row("event_user", None, {"new_colbool": pendulum.now()})) + schema.coerce_row("event_user", None, {"new_colbool": pendulum.now()}) assert exc_val.value.table_name == "event_user" assert exc_val.value.column_name == "new_colbool__v_timestamp" assert exc_val.value.from_type == "timestamp" @@ -208,7 +208,7 @@ def test_coerce_complex_variant(schema: Schema) -> None: def test_supports_variant_pua_decode(schema: Schema) -> None: rows = load_json_case("pua_encoded_row") - normalized_row = list(schema.normalize_data_item(schema, rows[0], "0912uhj222")) + normalized_row = list(schema.normalize_data_item(schema, rows[0], "0912uhj222", "event")) # pua encoding still present assert normalized_row[0][1]["wad"].startswith("") # decode pua @@ -223,7 +223,7 @@ def test_supports_variant(schema: Schema) -> None: rows = [{"evm": Wei.from_int256(2137*10**16, decimals=18)}, {"evm": Wei.from_int256(2**256-1)}] normalized_rows = [] for row in rows: - normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131")) + normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131", "event")) # row 1 contains Wei assert isinstance(normalized_rows[0][1]["evm"], Wei) assert normalized_rows[0][1]["evm"] == Wei("21.37") @@ -281,7 +281,7 @@ def __call__(self) -> Any: rows = [{"pv": PureVariant(3377)}, {"pv": PureVariant(21.37)}] normalized_rows = [] for row in rows: - normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131")) + normalized_rows.extend(schema.normalize_data_item(schema, row, "128812.2131", "event")) assert normalized_rows[0][1]["pv"]() == 3377 assert normalized_rows[1][1]["pv"]() == ("text", 21.37) # first normalized row fits into schema (pv is int) @@ -338,7 +338,7 @@ def test_infer_with_autodetection(schema: Schema) -> None: def test_update_schema_parent_missing(schema: Schema) -> None: - tab1 = utils.new_table("tab1", parent_name="tab_parent") + tab1 = utils.new_table("tab1", parent_table_name="tab_parent") # tab_parent is missing in schema with pytest.raises(ParentTableNotFoundException) as exc_val: schema.update_schema(tab1) @@ -373,18 +373,15 @@ def test_update_schema_table_prop_conflict(schema: Schema) -> None: # without write disposition will merge del tab1_u2["write_disposition"] schema.update_schema(tab1_u2) - # child table merge checks recursively - child_tab1 = utils.new_table("child_tab", parent_name="tab_parent") - schema.update_schema(child_tab1) - child_tab1_u1 = deepcopy(child_tab1) - # parent table is replace - child_tab1_u1["write_disposition"] = "append" + # tab1 no write disposition, table update has write disposition + tab1["write_disposition"] = None + tab1_u2["write_disposition"] = "merge" + # this will not merge with pytest.raises(TablePropertiesConflictException) as exc_val: - schema.update_schema(child_tab1_u1) - assert exc_val.value.prop_name == "write_disposition" - # this will pass - child_tab1_u1["write_disposition"] = "replace" - schema.update_schema(child_tab1_u1) + schema.update_schema(tab1_u2) + # both write dispositions are None + tab1_u2["write_disposition"] = None + schema.update_schema(tab1_u2) def test_update_schema_column_conflict(schema: Schema) -> None: diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 
07297cf270..b6fd7abb63 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -2,7 +2,8 @@ import pytest from dlt.common import pendulum -from dlt.common.configuration import SchemaVolumeConfiguration, make_configuration +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.exceptions import DictValidationException from dlt.common.schema.typing import TColumnName, TSimpleRegex, COLUMN_HINTS from dlt.common.typing import DictStrAny, StrAny @@ -11,19 +12,18 @@ from dlt.common.schema.exceptions import InvalidSchemaName, ParentTableNotFoundException, SchemaEngineNoUpgradePathException from dlt.common.storages import SchemaStorage -from tests.utils import autouse_root_storage +from tests.utils import autouse_test_storage from tests.common.utils import load_json_case, load_yml_case SCHEMA_NAME = "event" -EXPECTED_FILE_NAME = f"{SCHEMA_NAME}_schema.json" +EXPECTED_FILE_NAME = f"{SCHEMA_NAME}.schema.json" @pytest.fixture def schema_storage() -> SchemaStorage: - C = make_configuration( - SchemaVolumeConfiguration, - SchemaVolumeConfiguration, - initial_values={ + C = resolve_configuration( + SchemaVolumeConfiguration(), + initial_value={ "import_schema_path": "tests/common/cases/schemas/rasa", "external_schema_format": "json" }) @@ -54,9 +54,9 @@ def cn_schema() -> Schema: def test_normalize_schema_name(schema: Schema) -> None: - assert schema.normalize_schema_name("BAN_ANA") == "banana" - assert schema.normalize_schema_name("event-.!:value") == "eventvalue" - assert schema.normalize_schema_name("123event-.!:value") == "s123eventvalue" + assert schema.normalize_schema_name("BAN_ANA") == "ban_ana" + assert schema.normalize_schema_name("event-.!:value") == "event_value" + assert schema.normalize_schema_name("123event-.!:value") == "_123event_value" with pytest.raises(ValueError): assert schema.normalize_schema_name("") with pytest.raises(ValueError): @@ -117,8 +117,8 @@ def test_column_name_validator(schema: Schema) -> None: def test_invalid_schema_name() -> None: with pytest.raises(InvalidSchemaName) as exc: - Schema("a_b") - assert exc.value.name == "a_b" + Schema("a!b") + assert exc.value.name == "a!b" @pytest.mark.parametrize("columns,hint,value", [ @@ -143,7 +143,7 @@ def test_save_store_schema(schema: Schema, schema_storage: SchemaStorage) -> Non assert not schema_storage.storage.has_file(EXPECTED_FILE_NAME) saved_file_name = schema_storage.save_schema(schema) # return absolute path - assert saved_file_name == schema_storage.storage._make_path(EXPECTED_FILE_NAME) + assert saved_file_name == schema_storage.storage.make_full_path(EXPECTED_FILE_NAME) assert schema_storage.storage.has_file(EXPECTED_FILE_NAME) schema_copy = schema_storage.load_schema("event") assert schema.name == schema_copy.name @@ -158,7 +158,7 @@ def test_save_store_schema_custom_normalizers(cn_schema: Schema, schema_storage: def test_upgrade_engine_v1_schema() -> None: - schema_dict: DictStrAny = load_json_case("schemas/ev1/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") # ensure engine v1 assert schema_dict["engine_version"] == 1 # schema_dict will be updated to new engine version @@ -168,14 +168,14 @@ def test_upgrade_engine_v1_schema() -> None: assert len(schema_dict["tables"]) == 27 # upgrade schema eng 2 -> 4 - schema_dict: DictStrAny = load_json_case("schemas/ev2/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev2/event.schema") assert 
schema_dict["engine_version"] == 2 upgraded = utils.upgrade_engine_version(schema_dict, from_engine=2, to_engine=4) assert upgraded["engine_version"] == 4 utils.validate_stored_schema(upgraded) # upgrade 1 -> 4 - schema_dict: DictStrAny = load_json_case("schemas/ev1/event_schema") + schema_dict: DictStrAny = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 upgraded = utils.upgrade_engine_version(schema_dict, from_engine=1, to_engine=4) assert upgraded["engine_version"] == 4 @@ -183,7 +183,7 @@ def test_upgrade_engine_v1_schema() -> None: def test_unknown_engine_upgrade() -> None: - schema_dict: TStoredSchema = load_json_case("schemas/ev1/event_schema") + schema_dict: TStoredSchema = load_json_case("schemas/ev1/event.schema") # there's no path to migrate 3 -> 2 schema_dict["engine_version"] = 3 with pytest.raises(SchemaEngineNoUpgradePathException): @@ -242,7 +242,7 @@ def test_rasa_event_hints(columns: Sequence[str], hint: str, value: bool, schema def test_filter_hints_table() -> None: # this schema contains event_bot table with expected hints - schema_dict: TStoredSchema = load_json_case("schemas/ev1/event_schema") + schema_dict: TStoredSchema = load_json_case("schemas/ev1/event.schema") schema = Schema.from_dict(schema_dict) # get all not_null columns on event bot_case: StrAny = load_json_case("mod_bot_case") @@ -369,16 +369,16 @@ def test_compare_columns() -> None: ]) # columns identical with self for c in table["columns"].values(): - assert utils.compare_columns(c, c) is True - assert utils.compare_columns(table["columns"]["col3"], table["columns"]["col4"]) is True + assert utils.compare_column(c, c) is True + assert utils.compare_column(table["columns"]["col3"], table["columns"]["col4"]) is True # data type may not differ - assert utils.compare_columns(table["columns"]["col1"], table["columns"]["col3"]) is False + assert utils.compare_column(table["columns"]["col1"], table["columns"]["col3"]) is False # nullability may not differ - assert utils.compare_columns(table["columns"]["col1"], table["columns"]["col2"]) is False + assert utils.compare_column(table["columns"]["col1"], table["columns"]["col2"]) is False # any of the hints may differ for hint in COLUMN_HINTS: table["columns"]["col3"][hint] = True - assert utils.compare_columns(table["columns"]["col3"], table["columns"]["col4"]) is True + assert utils.compare_column(table["columns"]["col3"], table["columns"]["col4"]) is True def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: @@ -390,12 +390,12 @@ def assert_new_schema_values_custom_normalizers(schema: Schema) -> None: # call normalizers assert schema.normalize_column_name("a") == "column_a" assert schema.normalize_table_name("a__b") == "A__b" - assert schema.normalize_schema_name("1A_b") == "s1ab" + assert schema.normalize_schema_name("1A_b") == "1a_b" # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] - row = list(schema.normalize_data_item(schema, {"bool": True}, "load_id")) - assert row[0] == (("table", None), {"bool": True}) + row = list(schema.normalize_data_item(schema, {"bool": True}, "load_id", "a_table")) + assert row[0] == (("a_table", None), {"bool": True}) def assert_new_schema_values(schema: Schema) -> None: @@ -413,11 +413,11 @@ def assert_new_schema_values(schema: Schema) -> None: # call normalizers assert schema.normalize_column_name("A") == "a" assert schema.normalize_table_name("A__B") == 
"a__b" - assert schema.normalize_schema_name("1A_b") == "s1ab" + assert schema.normalize_schema_name("1A_b") == "_1_a_b" # assumes elements are normalized assert schema.normalize_make_path("A", "B", "!C") == "A__B__!C" assert schema.normalize_break_path("A__B__!C") == ["A", "B", "!C"] - schema.normalize_data_item(schema, {}, "load_id") + schema.normalize_data_item(schema, {}, "load_id", schema.name) # check default tables tables = schema.tables assert "_dlt_version" in tables diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index b8be8be019..79d23eb417 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -115,7 +115,7 @@ def test_version_preserve_on_reload(remove_defaults: bool) -> None: assert saved_schema.stored_version_hash == schema.stored_version_hash # serialize as yaml, for that use a schema that was stored in json - rasa_v4: TStoredSchema = load_json_case("schemas/rasa/event_schema") + rasa_v4: TStoredSchema = load_json_case("schemas/rasa/event.schema") rasa_schema = Schema.from_dict(rasa_v4) rasa_yml = rasa_schema.to_pretty_yaml(remove_defaults=remove_defaults) saved_rasa_schema = Schema.from_dict(yaml.safe_load(rasa_yml)) diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index 7299bb777a..b978670b2a 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -1,10 +1,73 @@ -from dlt.common.file_storage import FileStorage -from dlt.common.utils import encoding_for_mode +import os +import pytest -from tests.utils import TEST_STORAGE +from dlt.common.storages.file_storage import FileStorage +from dlt.common.utils import encoding_for_mode, uniq_id +from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, test_storage -FileStorage(TEST_STORAGE, makedirs=True) + +# FileStorage(TEST_STORAGE_ROOT, makedirs=True) + + +def test_storage_init(test_storage: FileStorage) -> None: + # must be absolute path + assert os.path.isabs(test_storage.storage_path) + # may not contain file name (ends with / or \) + assert os.path.basename(test_storage.storage_path) == "" + + # TODO: write more cases + + +def test_make_full_path(test_storage: FileStorage) -> None: + # fully within storage + path = test_storage.make_full_path("dir/to/file") + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + # overlapped with storage + path = test_storage.make_full_path(f"{TEST_STORAGE_ROOT}/dir/to/file") + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + assert path.count(TEST_STORAGE_ROOT) == 1 + # absolute path with different root than TEST_STORAGE_ROOT + with pytest.raises(ValueError): + test_storage.make_full_path(f"/{TEST_STORAGE_ROOT}/dir/to/file") + # absolute overlapping path + path = test_storage.make_full_path(os.path.abspath(f"{TEST_STORAGE_ROOT}/dir/to/file")) + assert path.endswith("/" + TEST_STORAGE_ROOT + "/dir/to/file") + + +def test_hard_links(test_storage: FileStorage) -> None: + content = uniq_id() + test_storage.save("file.txt", content) + test_storage.link_hard("file.txt", "link.txt") + # it is a file + assert test_storage.has_file("link.txt") + # should have same content as file + assert test_storage.load("link.txt") == content + # should be linked + with test_storage.open_file("file.txt", mode="a") as f: + f.write(content) + assert test_storage.load("link.txt") == content * 2 + with test_storage.open_file("link.txt", mode="a") as f: + f.write(content) + assert 
test_storage.load("file.txt") == content * 3 + # delete original file + test_storage.delete("file.txt") + assert not test_storage.has_file("file.txt") + assert test_storage.load("link.txt") == content * 3 + + +def test_validate_file_name_component() -> None: + # no dots + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a.b") + # no slashes + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a/b") + # no backslashes + with pytest.raises(ValueError): + FileStorage.validate_file_name_component("a\\b") + + FileStorage.validate_file_name_component("BAN__ANA is allowed") def test_encoding_for_mode() -> None: @@ -17,15 +80,15 @@ def test_encoding_for_mode() -> None: def test_save_atomic_encode() -> None: tstr = "data'ऄअआइ''ईउऊऋऌऍऎए');" - FileStorage.save_atomic(TEST_STORAGE, "file.txt", tstr) - storage = FileStorage(TEST_STORAGE) + FileStorage.save_atomic(TEST_STORAGE_ROOT, "file.txt", tstr) + storage = FileStorage(TEST_STORAGE_ROOT) with storage.open_file("file.txt") as f: assert f.encoding == "utf-8" assert f.read() == tstr bstr = b"axa\0x0\0x0" - FileStorage.save_atomic(TEST_STORAGE, "file.bin", bstr, file_type="b") - storage = FileStorage(TEST_STORAGE, file_type="b") + FileStorage.save_atomic(TEST_STORAGE_ROOT, "file.bin", bstr, file_type="b") + storage = FileStorage(TEST_STORAGE_ROOT, file_type="b") with storage.open_file("file.bin", mode="r") as f: assert hasattr(f, "encoding") is False assert f.read() == bstr diff --git a/tests/common/storages/test_loader_storage.py b/tests/common/storages/test_loader_storage.py index a829f6207f..980364946e 100644 --- a/tests/common/storages/test_loader_storage.py +++ b/tests/common/storages/test_loader_storage.py @@ -6,24 +6,25 @@ from dlt.common import sleep from dlt.common.schema import Schema from dlt.common.storages.load_storage import LoadStorage, TParsedJobFileName -from dlt.common.configuration import LoadVolumeConfiguration, make_configuration +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.specs import LoadVolumeConfiguration from dlt.common.storages.exceptions import NoMigrationPathException from dlt.common.typing import StrAny from dlt.common.utils import uniq_id -from tests.utils import TEST_STORAGE, write_version, autouse_root_storage +from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage @pytest.fixture def storage() -> LoadStorage: - C = make_configuration(LoadVolumeConfiguration, LoadVolumeConfiguration) - s = LoadStorage(True, C, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + C = resolve_configuration(LoadVolumeConfiguration()) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS, C) return s def test_complete_successful_package(storage: LoadStorage) -> None: # should delete package in full - storage.delete_completed_jobs = True + storage.config.delete_completed_jobs = True load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) assert storage.storage.has_folder(storage.get_package_path(load_id)) storage.complete_job(load_id, file_name) @@ -34,7 +35,7 @@ def test_complete_successful_package(storage: LoadStorage) -> None: assert not storage.storage.has_folder(storage.get_completed_package_path(load_id)) # do not delete completed jobs - storage.delete_completed_jobs = False + storage.config.delete_completed_jobs = False load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) storage.complete_job(load_id, file_name) 
storage.complete_load_package(load_id) @@ -46,7 +47,7 @@ def test_complete_successful_package(storage: LoadStorage) -> None: def test_complete_package_failed_jobs(storage: LoadStorage) -> None: # loads with failed jobs are always persisted - storage.delete_completed_jobs = True + storage.config.delete_completed_jobs = True load_id, file_name = start_loading_file(storage, [{"content": "a"}, {"content": "b"}]) assert storage.storage.has_folder(storage.get_package_path(load_id)) storage.fail_job(load_id, file_name, "EXCEPTION") @@ -72,7 +73,7 @@ def test_save_load_schema(storage: LoadStorage) -> None: def test_job_elapsed_time_seconds(storage: LoadStorage) -> None: load_id, fn = start_loading_file(storage, "test file") - fp = storage.storage._make_path(storage._get_job_file_path(load_id, "started_jobs", fn)) + fp = storage.storage.make_full_path(storage._get_job_file_path(load_id, "started_jobs", fn)) elapsed = storage.job_elapsed_time_seconds(fp) sleep(0.3) # do not touch file @@ -141,22 +142,22 @@ def test_process_schema_update(storage: LoadStorage) -> None: def test_full_migration_path() -> None: # create directory structure - s = LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "1.0.0") # must be able to migrate to current version - s = LoadStorage(False, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) assert s.version == LoadStorage.STORAGE_VERSION def test_unknown_migration_path() -> None: # create directory structure - s = LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "10.0.0") # must be able to migrate to current version with pytest.raises(NoMigrationPathException): - LoadStorage(False, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) def start_loading_file(s: LoadStorage, content: Sequence[StrAny]) -> Tuple[str, str]: diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index e1eb8552bd..8deb472140 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -1,11 +1,13 @@ import pytest -from dlt.common.storages.exceptions import NoMigrationPathException -from dlt.common.storages.normalize_storage import NormalizeStorage -from dlt.common.configuration import NormalizeVolumeConfiguration from dlt.common.utils import uniq_id +from dlt.common.storages import NormalizeStorage +from dlt.common.storages.exceptions import NoMigrationPathException +from dlt.common.configuration.specs import NormalizeVolumeConfiguration +from dlt.common.storages.normalize_storage import TParsedNormalizeFileName + +from tests.utils import write_version, autouse_test_storage -from tests.utils import write_version, autouse_root_storage @pytest.mark.skip() def test_load_events_and_group_by_sender() -> None: @@ -13,39 +15,32 @@ def test_load_events_and_group_by_sender() -> None: pass -@pytest.mark.skip() -def test_chunk_by_events() -> None: - # TODO: should distribute ~ N events evenly among m cores with fallback for small amounts of events - pass - - def 
test_build_extracted_file_name() -> None: load_id = uniq_id() - name = NormalizeStorage.build_extracted_file_name("event", "table", 121, load_id) + name = NormalizeStorage.build_extracted_file_stem("event", "table_with_parts__many", load_id) + ".jsonl" assert NormalizeStorage.get_schema_name(name) == "event" - assert NormalizeStorage.get_events_count(name) == 121 - assert NormalizeStorage._parse_extracted_file_name(name) == (121, load_id, "event") + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("event", "table_with_parts__many", load_id) # empty schema should be supported - name = NormalizeStorage.build_extracted_file_name("", "table", 121, load_id) - assert NormalizeStorage._parse_extracted_file_name(name) == (121, load_id, "") + name = NormalizeStorage.build_extracted_file_stem("", "table", load_id) + ".jsonl" + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("", "table", load_id) def test_full_migration_path() -> None: # create directory structure - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) # overwrite known initial version write_version(s.storage, "1.0.0") # must be able to migrate to current version - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) assert s.version == NormalizeStorage.STORAGE_VERSION def test_unknown_migration_path() -> None: # create directory structure - s = NormalizeStorage(True, NormalizeVolumeConfiguration) + s = NormalizeStorage(True) # overwrite known initial version write_version(s.storage, "10.0.0") # must be able to migrate to current version with pytest.raises(NoMigrationPathException): - NormalizeStorage(False, NormalizeVolumeConfiguration) + NormalizeStorage(False) diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 28755c6acb..341fdfd2e5 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -4,45 +4,42 @@ import yaml from dlt.common import json -from dlt.common.configuration import make_configuration -from dlt.common.file_storage import FileStorage from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema from dlt.common.schema.utils import default_normalizers -from dlt.common.configuration import SchemaVolumeConfiguration +from dlt.common.configuration.specs import SchemaVolumeConfiguration from dlt.common.storages.exceptions import InStorageSchemaModified, SchemaNotFoundError -from dlt.common.storages import SchemaStorage, LiveSchemaStorage -from dlt.common.typing import DictStrAny +from dlt.common.storages import SchemaStorage, LiveSchemaStorage, FileStorage -from tests.utils import autouse_root_storage, TEST_STORAGE +from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT from tests.common.utils import load_yml_case, yml_case_path @pytest.fixture def storage() -> SchemaStorage: - return init_storage() + return init_storage(SchemaVolumeConfiguration()) @pytest.fixture def synced_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE + "/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE + "/import"}) + return init_storage(SchemaVolumeConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/import")) @pytest.fixture def ie_storage() -> SchemaStorage: # will be created in /schemas - return init_storage({"IMPORT_SCHEMA_PATH": TEST_STORAGE + 
"/import", "EXPORT_SCHEMA_PATH": TEST_STORAGE + "/export"}) + return init_storage(SchemaVolumeConfiguration(import_schema_path=TEST_STORAGE_ROOT + "/import", export_schema_path=TEST_STORAGE_ROOT + "/export")) -def init_storage(initial: DictStrAny = None) -> SchemaStorage: - C = make_configuration(SchemaVolumeConfiguration, SchemaVolumeConfiguration, initial_values=initial) +def init_storage(C: SchemaVolumeConfiguration) -> SchemaStorage: # use live schema storage for test which must be backward compatible with schema storage s = LiveSchemaStorage(C, makedirs=True) - if C.EXPORT_SCHEMA_PATH: - os.makedirs(C.EXPORT_SCHEMA_PATH, exist_ok=True) - if C.IMPORT_SCHEMA_PATH: - os.makedirs(C.IMPORT_SCHEMA_PATH, exist_ok=True) + assert C is s.config + if C.export_schema_path: + os.makedirs(C.export_schema_path, exist_ok=True) + if C.import_schema_path: + os.makedirs(C.import_schema_path, exist_ok=True) return s @@ -86,7 +83,7 @@ def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: Sch # the import schema gets modified storage_schema.tables["_dlt_loads"]["write_disposition"] = "append" storage_schema.tables.pop("event_user") - synced_storage._export_schema(storage_schema, synced_storage.C.EXPORT_SCHEMA_PATH) + synced_storage._export_schema(storage_schema, synced_storage.config.export_schema_path) # now load will import again reloaded_schema = synced_storage.load_schema("ethereum") # we have overwritten storage schema @@ -113,7 +110,7 @@ def test_store_schema_tampered(synced_storage: SchemaStorage, storage: SchemaSto def test_schema_export(ie_storage: SchemaStorage) -> None: schema = Schema("ethereum") - fs = FileStorage(ie_storage.C.EXPORT_SCHEMA_PATH) + fs = FileStorage(ie_storage.config.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") # no exported schema assert not fs.has_file(exported_name) @@ -135,6 +132,10 @@ def test_list_schemas(storage: SchemaStorage) -> None: assert set(storage.list_schemas()) == set(["ethereum", "event"]) storage.remove_schema("event") assert storage.list_schemas() == ["ethereum"] + # add schema with _ in the name + schema = Schema("dlt_pipeline") + storage.save_schema(schema) + assert set(storage.list_schemas()) == set(["ethereum", "dlt_pipeline"]) def test_remove_schema(storage: SchemaStorage) -> None: @@ -192,7 +193,7 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: assert schema.version_hash == schema_hash assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # we have simple schema in export folder - fs = FileStorage(ie_storage.C.EXPORT_SCHEMA_PATH) + fs = FileStorage(ie_storage.config.export_schema_path) exported_name = ie_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] @@ -206,7 +207,7 @@ def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> No synced_storage.save_schema(schema) assert schema._imported_version_hash == "njJAySgJRs2TqGWgQXhP+3pCh1A1hXcqe77BpM7JtOU=" # import schema is overwritten - fs = FileStorage(synced_storage.C.IMPORT_SCHEMA_PATH) + fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") exported_schema = yaml.safe_load(fs.load(exported_name)) assert schema.version_hash == exported_schema["version_hash"] == schema_hash @@ -242,7 +243,7 @@ def test_save_store_schema(storage: SchemaStorage) -> None: def 
prepare_import_folder(storage: SchemaStorage) -> None: - shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage._make_path("../import/ethereum_schema.yaml")) + shutil.copy(yml_case_path("schemas/eth/ethereum_schema_v4"), storage.storage.make_full_path("../import/ethereum.schema.yaml")) def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: diff --git a/tests/common/storages/test_versioned_storage.py b/tests/common/storages/test_versioned_storage.py index e4bcbf7a37..ff23480a48 100644 --- a/tests/common/storages/test_versioned_storage.py +++ b/tests/common/storages/test_versioned_storage.py @@ -1,11 +1,11 @@ import pytest import semver -from dlt.common.file_storage import FileStorage +from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException from dlt.common.storages.versioned_storage import VersionedStorage -from tests.utils import write_version, root_storage +from tests.utils import write_version, test_storage class MigratedStorage(VersionedStorage): @@ -19,41 +19,41 @@ def migrate_storage(self, from_version: semver.VersionInfo, to_version: semver.V self._save_version(from_version) -def test_new_versioned_storage(root_storage: FileStorage) -> None: - v = VersionedStorage("1.0.1", True, root_storage) +def test_new_versioned_storage(test_storage: FileStorage) -> None: + v = VersionedStorage("1.0.1", True, test_storage) assert v.version == "1.0.1" -def test_new_versioned_storage_non_owner(root_storage: FileStorage) -> None: +def test_new_versioned_storage_non_owner(test_storage: FileStorage) -> None: with pytest.raises(WrongStorageVersionException) as wsve: - VersionedStorage("1.0.1", False, root_storage) - assert wsve.value.storage_path == root_storage.storage_path + VersionedStorage("1.0.1", False, test_storage) + assert wsve.value.storage_path == test_storage.storage_path assert wsve.value.target_version == "1.0.1" assert wsve.value.initial_version == "0.0.0" -def test_migration(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") - v = MigratedStorage("1.2.0", True, root_storage) +def test_migration(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") + v = MigratedStorage("1.2.0", True, test_storage) assert v.version == "1.2.0" -def test_unknown_migration_path(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") +def test_unknown_migration_path(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") with pytest.raises(NoMigrationPathException) as wmpe: - MigratedStorage("1.3.0", True, root_storage) + MigratedStorage("1.3.0", True, test_storage) assert wmpe.value.migrated_version == "1.2.0" -def test_only_owner_migrates(root_storage: FileStorage) -> None: - write_version(root_storage, "1.0.0") +def test_only_owner_migrates(test_storage: FileStorage) -> None: + write_version(test_storage, "1.0.0") with pytest.raises(WrongStorageVersionException) as wmpe: - MigratedStorage("1.2.0", False, root_storage) + MigratedStorage("1.2.0", False, test_storage) assert wmpe.value.initial_version == "1.0.0" -def test_downgrade_not_possible(root_storage: FileStorage) -> None: - write_version(root_storage, "1.2.0") +def test_downgrade_not_possible(test_storage: FileStorage) -> None: + write_version(test_storage, "1.2.0") with pytest.raises(NoMigrationPathException) as wmpe: - MigratedStorage("1.1.0", True, root_storage) + MigratedStorage("1.1.0", True, test_storage) 
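+    # downgrading from "1.2.0" to "1.1.0" has no migration path, so the exception reports the version the storage stays at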
assert wmpe.value.migrated_version == "1.2.0" \ No newline at end of file diff --git a/tests/common/test_configuration.py b/tests/common/test_configuration.py deleted file mode 100644 index ca9da8f928..0000000000 --- a/tests/common/test_configuration.py +++ /dev/null @@ -1,463 +0,0 @@ -import pytest -from os import environ -from typing import Any, Dict, List, NewType, Optional, Tuple - -from dlt.common.typing import TSecretValue -from dlt.common.configuration import ( - RunConfiguration, ConfigEntryMissingException, ConfigFileNotFoundException, - ConfigEnvValueCannotBeCoercedException, BaseConfiguration, utils) -from dlt.common.configuration.utils import (_coerce_single_value, IS_DEVELOPMENT_CONFIG_KEY, - _get_config_attrs_with_hints, - is_direct_descendant, make_configuration) -from dlt.common.configuration.providers import environ as environ_provider - -from tests.utils import preserve_environ - -# used to test version -__version__ = "1.0.5" - -IS_DEVELOPMENT_CONFIG = 'DEBUG' -NONE_CONFIG_VAR = 'NoneConfigVar' -COERCIONS = { - 'STR_VAL': 'test string', - 'INT_VAL': 12345, - 'BOOL_VAL': True, - 'LIST_VAL': [1, "2", [3]], - 'DICT_VAL': { - 'a': 1, - "b": "2" - }, - 'TUPLE_VAL': (1, 2, '7'), - 'SET_VAL': {1, 2, 3}, - 'BYTES_VAL': b'Hello World!', - 'FLOAT_VAL': 1.18927, - 'ANY_VAL': "function() {}", - 'NONE_VAL': "none", - 'COMPLEX_VAL': { - "_": (1440, ["*"], []), - "change-email": (560, ["*"], []) - } -} - -INVALID_COERCIONS = { - # 'STR_VAL': 'test string', # string always OK - 'INT_VAL': "a12345", - 'BOOL_VAL': "Yes", # bool overridden by string - that is the most common problem - 'LIST_VAL': {1, "2", 3.0}, - 'DICT_VAL': "{'a': 1, 'b', '2'}", - 'TUPLE_VAL': [1, 2, '7'], - 'SET_VAL': [1, 2, 3], - 'BYTES_VAL': 'Hello World!', - 'FLOAT_VAL': "invalid" -} - -EXCEPTED_COERCIONS = { - # allows to use int for float - 'FLOAT_VAL': 10, - # allows to use float for str - 'STR_VAL': 10.0 -} - -COERCED_EXCEPTIONS = { - # allows to use int for float - 'FLOAT_VAL': 10.0, - # allows to use float for str - 'STR_VAL': "10.0" -} - - -class SimpleConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - - -class WrongConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - NoneConfigVar = None - LOG_COLOR: bool = True - - -class SecretConfiguration(RunConfiguration): - PIPELINE_NAME: str = "secret" - SECRET_VALUE: TSecretValue = None - - -class SecretKubeConfiguration(RunConfiguration): - PIPELINE_NAME: str = "secret kube" - SECRET_KUBE: TSecretValue = None - - -class TestCoercionConfiguration(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - STR_VAL: str = None - INT_VAL: int = None - BOOL_VAL: bool = None - LIST_VAL: list = None # type: ignore - DICT_VAL: dict = None # type: ignore - TUPLE_VAL: tuple = None # type: ignore - BYTES_VAL: bytes = None - SET_VAL: set = None # type: ignore - FLOAT_VAL: float = None - ANY_VAL: Any = None - NONE_VAL = None - COMPLEX_VAL: Dict[str, Tuple[int, List[str], List[str]]] = None - - -class VeryWrongConfiguration(WrongConfiguration): - PIPELINE_NAME: str = "Some Name" - STR_VAL: str = "" - INT_VAL: int = None - LOG_COLOR: str = "1" # type: ignore - - -class ConfigurationWithOptionalTypes(RunConfiguration): - PIPELINE_NAME: str = "Some Name" - - STR_VAL: Optional[str] = None - INT_VAL: Optional[int] = None - BOOL_VAL: bool = True - - -class ProdConfigurationWithOptionalTypes(ConfigurationWithOptionalTypes): - PROD_VAL: str = "prod" - - -class MockProdConfiguration(RunConfiguration): - PIPELINE_NAME: str = "comp" - - -class 
MockProdConfigurationVar(RunConfiguration): - PIPELINE_NAME: str = "comp" - - -class NamespacedConfiguration(BaseConfiguration): - __namespace__ = "DLT_TEST" - - PASSWORD: str = None - - -LongInteger = NewType("LongInteger", int) -FirstOrderStr = NewType("FirstOrderStr", str) -SecondOrderStr = NewType("SecondOrderStr", FirstOrderStr) - - -@pytest.fixture(scope="function") -def environment() -> Any: - environ.clear() - return environ - - -def test_run_configuration_gen_name(environment: Any) -> None: - C = make_configuration(RunConfiguration, RunConfiguration) - assert C.PIPELINE_NAME.startswith("dlt_") - - -def test_configuration_to_dict(environment: Any) -> None: - expected_dict = { - 'CONFIG_FILES_STORAGE_PATH': '_storage/config/%s', - 'IS_DEVELOPMENT_CONFIG': True, - 'LOG_FORMAT': '{asctime}|[{levelname:<21}]|{process}|{name}|{filename}|{funcName}:{lineno}|{message}', - 'LOG_LEVEL': 'DEBUG', - 'PIPELINE_NAME': 'secret', - 'PROMETHEUS_PORT': None, - 'REQUEST_TIMEOUT': (15, 300), - 'SECRET_VALUE': None, - 'SENTRY_DSN': None - } - assert SecretConfiguration.as_dict() == {k.lower():v for k,v in expected_dict.items()} - assert SecretConfiguration.as_dict(lowercase=False) == expected_dict - - environment["SECRET_VALUE"] = "secret" - C = make_configuration(SecretConfiguration, SecretConfiguration) - d = C.as_dict(lowercase=False) - expected_dict["_VERSION"] = d["_VERSION"] - expected_dict["SECRET_VALUE"] = "secret" - assert d == expected_dict - - -def test_configuration_rise_exception_when_config_is_not_complete() -> None: - with pytest.raises(ConfigEntryMissingException) as config_entry_missing_exception: - keys = _get_config_attrs_with_hints(WrongConfiguration) - utils._is_config_bounded(WrongConfiguration, keys) - - assert 'NoneConfigVar' in config_entry_missing_exception.value.missing_set - - -def test_optional_types_are_not_required() -> None: - # this should not raise an exception - keys = _get_config_attrs_with_hints(ConfigurationWithOptionalTypes) - utils._is_config_bounded(ConfigurationWithOptionalTypes, keys) - # make optional config - make_configuration(ConfigurationWithOptionalTypes, ConfigurationWithOptionalTypes) - # make config with optional values - make_configuration( - ProdConfigurationWithOptionalTypes, - ProdConfigurationWithOptionalTypes, - initial_values={"INT_VAL": None} - ) - - -def test_configuration_apply_adds_environment_variable_to_config(environment: Any) -> None: - environment[NONE_CONFIG_VAR] = "Some" - - keys = _get_config_attrs_with_hints(WrongConfiguration) - utils._apply_environ_to_config(WrongConfiguration, keys) - utils._is_config_bounded(WrongConfiguration, keys) - - # NoneConfigVar has no hint so value not coerced from string - assert WrongConfiguration.NoneConfigVar == environment[NONE_CONFIG_VAR] - - -def test_configuration_resolve(environment: Any) -> None: - environment[IS_DEVELOPMENT_CONFIG] = 'True' - - keys = _get_config_attrs_with_hints(SimpleConfiguration) - utils._apply_environ_to_config(SimpleConfiguration, keys) - utils._is_config_bounded(SimpleConfiguration, keys) - - # value will be coerced to bool - assert RunConfiguration.IS_DEVELOPMENT_CONFIG is True - - -def test_find_all_keys() -> None: - keys = _get_config_attrs_with_hints(VeryWrongConfiguration) - # assert hints and types: NoneConfigVar has no type hint and LOG_COLOR had it hint overwritten in derived class - assert set({'STR_VAL': str, 'INT_VAL': int, 'NoneConfigVar': None, 'LOG_COLOR': str}.items()).issubset(keys.items()) - - -def test_coercions(environment: Any) -> None: - for key, 
value in COERCIONS.items(): - environment[key] = str(value) - - keys = _get_config_attrs_with_hints(TestCoercionConfiguration) - utils._apply_environ_to_config(TestCoercionConfiguration, keys) - utils._is_config_bounded(TestCoercionConfiguration, keys) - - for key in COERCIONS: - assert getattr(TestCoercionConfiguration, key) == COERCIONS[key] - - -def test_invalid_coercions(environment: Any) -> None: - config_keys = _get_config_attrs_with_hints(TestCoercionConfiguration) - for key, value in INVALID_COERCIONS.items(): - try: - environment[key] = str(value) - utils._apply_environ_to_config(TestCoercionConfiguration, config_keys) - except ConfigEnvValueCannotBeCoercedException as coerc_exc: - # must fail excatly on expected value - if coerc_exc.attr_name != key: - raise - # overwrite with valid value and go to next env - environment[key] = str(COERCIONS[key]) - continue - raise AssertionError("%s was coerced with %s which is invalid type" % (key, value)) - - -def test_excepted_coercions(environment: Any) -> None: - config_keys = _get_config_attrs_with_hints(TestCoercionConfiguration) - for k, v in EXCEPTED_COERCIONS.items(): - environment[k] = str(v) - utils._apply_environ_to_config(TestCoercionConfiguration, config_keys) - for key in EXCEPTED_COERCIONS: - assert getattr(TestCoercionConfiguration, key) == COERCED_EXCEPTIONS[key] - - -def test_development_config_detection(environment: Any) -> None: - # default is true - assert utils._is_development_config() - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - # explicit values - assert not utils._is_development_config() - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - assert utils._is_development_config() - # raise exception on env value that cannot be coerced to bool - with pytest.raises(ConfigEnvValueCannotBeCoercedException): - environment[IS_DEVELOPMENT_CONFIG_KEY] = "NONBOOL" - utils._is_development_config() - - -def test_make_configuration(environment: Any) -> None: - # fill up configuration - environment['INT_VAL'] = "1" - C = utils.make_configuration(WrongConfiguration, VeryWrongConfiguration) - assert not C.__is_partial__ - # default is true - assert is_direct_descendant(C, WrongConfiguration) - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - assert is_direct_descendant(utils.make_configuration(WrongConfiguration, VeryWrongConfiguration), VeryWrongConfiguration) - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - assert is_direct_descendant(utils.make_configuration(WrongConfiguration, VeryWrongConfiguration), WrongConfiguration) - - -def test_auto_derivation(environment: Any) -> None: - # make_configuration auto derives a type and never modifies the original type - environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - # auto derived type holds the value - assert C.SECRET_VALUE == "1" - # base type is untouched - assert SecretConfiguration.SECRET_VALUE is None - # type name is derived - assert C.__name__.startswith("SecretConfiguration_") - - -def test_initial_values(environment: Any) -> None: - # initial values will be overridden from env - environment["PIPELINE_NAME"] = "env name" - environment["CREATED_VAL"] = "12837" - # set initial values and allow partial config - C = make_configuration(TestCoercionConfiguration, TestCoercionConfiguration, - {"PIPELINE_NAME": "initial name", "NONE_VAL": type(environment), "CREATED_VAL": 878232, "BYTES_VAL": b"str"}, - accept_partial=True - ) - # from env - assert C.PIPELINE_NAME == "env name" - # from initial - assert C.BYTES_VAL == 
b"str" - assert C.NONE_VAL == type(environment) - # new prop overridden from env - assert environment["CREATED_VAL"] == "12837" - - -def test_accept_partial(environment: Any) -> None: - WrongConfiguration.NoneConfigVar = None - C = make_configuration(WrongConfiguration, WrongConfiguration, accept_partial=True) - assert C.NoneConfigVar is None - # partial resolution - assert C.__is_partial__ - - -def test_finds_version(environment: Any) -> None: - global __version__ - - v = __version__ - C = utils.make_configuration(SimpleConfiguration, SimpleConfiguration) - assert C._VERSION == v - try: - del globals()["__version__"] - # C is a type, not instance and holds the _VERSION from previous extract - delattr(C, "_VERSION") - C = utils.make_configuration(SimpleConfiguration, SimpleConfiguration) - assert not hasattr(C, "_VERSION") - finally: - __version__ = v - - -def test_secret(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration, SecretConfiguration) - environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - assert C.SECRET_VALUE == "1" - # mock the path to point to secret storage - # from dlt.common.configuration import config_utils - path = environ_provider.SECRET_STORAGE_PATH - del environment['SECRET_VALUE'] - try: - # must read a secret file - environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - assert C.SECRET_VALUE == "BANANA" - - # set some weird path, no secret file at all - del environment['SECRET_VALUE'] - environ_provider.SECRET_STORAGE_PATH = "!C:\\PATH%s" - with pytest.raises(ConfigEntryMissingException): - utils.make_configuration(SecretConfiguration, SecretConfiguration) - - # set env which is a fallback for secret not as file - environment['SECRET_VALUE'] = "1" - C = utils.make_configuration(SecretConfiguration, SecretConfiguration) - assert C.SECRET_VALUE == "1" - finally: - environ_provider.SECRET_STORAGE_PATH = path - - -def test_secret_kube_fallback(environment: Any) -> None: - path = environ_provider.SECRET_STORAGE_PATH - try: - environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/%s" - C = utils.make_configuration(SecretKubeConfiguration, SecretKubeConfiguration) - # all unix editors will add x10 at the end of file, it will be preserved - assert C.SECRET_KUBE == "kube\n" - # we propagate secrets back to environ and strip the whitespace - assert environment['SECRET_KUBE'] == "kube" - finally: - environ_provider.SECRET_STORAGE_PATH = path - - -def test_configuration_must_be_subclass_of_prod(environment: Any) -> None: - # fill up configuration - environment['INT_VAL'] = "1" - # prod must inherit from config - with pytest.raises(AssertionError): - # VeryWrongConfiguration does not descend inherit from ConfigurationWithOptionalTypes so it cannot be production config of it - utils.make_configuration(ConfigurationWithOptionalTypes, VeryWrongConfiguration) - - -def test_coerce_values() -> None: - with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", int) - assert _coerce_single_value("key", "some string", str) == "some string" - # Optional[str] has type object, mypy will never work properly... 
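# Illustrative sketch (simplified; not dlt's actual implementation): the hint-unwrapping
# behavior that test_coerce_values around this point checks — Optional[X] and NewType
# chains are reduced to a base runtime type before the string value is coerced.
import typing

LongInteger = typing.NewType("LongInteger", int)

def coerce_env_value(value: str, hint: typing.Any) -> typing.Any:
    if typing.get_origin(hint) is typing.Union:
        # Optional[X] -> X
        hint = next(a for a in typing.get_args(hint) if a is not type(None))
    while hasattr(hint, "__supertype__"):
        # NewType("LongInteger", int) -> int, possibly nested
        hint = hint.__supertype__
    return hint(value)  # raises ValueError when the value cannot be coerced

assert coerce_env_value("234", typing.Optional[LongInteger]) == 234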
- assert _coerce_single_value("key", "some string", Optional[str]) == "some string" # type: ignore - - assert _coerce_single_value("key", "234", int) == 234 - assert _coerce_single_value("key", "234", Optional[int]) == 234 # type: ignore - - # check coercions of NewTypes - assert _coerce_single_value("key", "test str X", FirstOrderStr) == "test str X" - assert _coerce_single_value("key", "test str X", Optional[FirstOrderStr]) == "test str X" # type: ignore - assert _coerce_single_value("key", "test str X", Optional[SecondOrderStr]) == "test str X" # type: ignore - assert _coerce_single_value("key", "test str X", SecondOrderStr) == "test str X" - assert _coerce_single_value("key", "234", LongInteger) == 234 - assert _coerce_single_value("key", "234", Optional[LongInteger]) == 234 # type: ignore - # this coercion should fail - with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", LongInteger) - with pytest.raises(ConfigEnvValueCannotBeCoercedException): - _coerce_single_value("key", "some string", Optional[LongInteger]) # type: ignore - - -def test_configuration_files_prod_path(environment: Any) -> None: - environment[IS_DEVELOPMENT_CONFIG_KEY] = "True" - C = utils.make_configuration(MockProdConfiguration, MockProdConfiguration) - assert C.CONFIG_FILES_STORAGE_PATH == "_storage/config/%s" - - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - C = utils.make_configuration(MockProdConfiguration, MockProdConfiguration) - assert C.IS_DEVELOPMENT_CONFIG is False - assert C.CONFIG_FILES_STORAGE_PATH == "/run/config/%s" - - -def test_configuration_files(environment: Any) -> None: - # overwrite config file paths - environment[IS_DEVELOPMENT_CONFIG_KEY] = "False" - environment["CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/%s" - C = utils.make_configuration(MockProdConfigurationVar, MockProdConfigurationVar) - assert C.CONFIG_FILES_STORAGE_PATH == environment["CONFIG_FILES_STORAGE_PATH"] - assert C.has_configuration_file("hasn't") is False - assert C.has_configuration_file("event_schema.json") is True - assert C.get_configuration_file_path("event_schema.json") == "./tests/common/cases/schemas/ev1/event_schema.json" - with C.open_configuration_file("event_schema.json", "r") as f: - f.read() - with pytest.raises(ConfigFileNotFoundException): - C.open_configuration_file("hasn't", "r") - - -def test_namespaced_configuration(environment: Any) -> None: - with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) - assert exc_val.value.missing_set == ["DLT_TEST__PASSWORD"] - assert exc_val.value.namespace == "DLT_TEST" - - # init vars work without namespace - C = utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration, initial_values={"PASSWORD": "PASS"}) - assert C.PASSWORD == "PASS" - - # env var must be prefixed - environment["PASSWORD"] = "PASS" - with pytest.raises(ConfigEntryMissingException) as exc_val: - utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) - environment["DLT_TEST__PASSWORD"] = "PASS" - C = utils.make_configuration(NamespacedConfiguration, NamespacedConfiguration) - assert C.PASSWORD == "PASS" diff --git a/tests/common/test_data_writers.py b/tests/common/test_data_writers.py index 439bae0748..456643f619 100644 --- a/tests/common/test_data_writers.py +++ b/tests/common/test_data_writers.py @@ -1,51 +1,55 @@ import io +import pytest +from typing import Iterator from dlt.common import pendulum -from 
dlt.common.dataset_writers import write_insert_values, escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier +from dlt.common.data_writers import escape_redshift_literal, escape_redshift_identifier, escape_bigquery_identifier +from dlt.common.data_writers.writers import DataWriter, InsertValuesWriter -from tests.common.utils import load_json_case +from tests.common.utils import load_json_case, row_to_column_schemas -def test_simple_insert_writer() -> None: - rows = load_json_case("simple_row") +@pytest.fixture +def insert_writer() -> Iterator[DataWriter]: with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + yield InsertValuesWriter(f) + + +def test_simple_insert_writer(insert_writer: DataWriter) -> None: + rows = load_json_case("simple_row") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[0].startswith("INSERT INTO {}") assert '","'.join(rows[0].keys()) in lines[0] assert lines[1] == "VALUES" assert len(lines) == 4 -def test_bytes_insert_writer() -> None: +def test_bytes_insert_writer(insert_writer: DataWriter) -> None: rows = [{"bytes": b"bytes"}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "(from_hex('6279746573'));" -def test_datetime_insert_writer() -> None: +def test_datetime_insert_writer(insert_writer: DataWriter) -> None: rows = [{"datetime": pendulum.from_timestamp(1658928602.575267)}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "('2022-07-27T13:30:02.575267+00:00');" -def test_date_insert_writer() -> None: +def test_date_insert_writer(insert_writer: DataWriter) -> None: rows = [{"date": pendulum.date(1974, 8, 11)}] - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2] == "('1974-08-11');" -def test_unicode_insert_writer() -> None: +def test_unicode_insert_writer(insert_writer: DataWriter) -> None: rows = load_json_case("weird_rows") - with io.StringIO() as f: - write_insert_values(f, rows, rows[0].keys()) - lines = f.getvalue().split("\n") + insert_writer.write_all(row_to_column_schemas(rows[0]), rows) + lines = insert_writer._f.getvalue().split("\n") assert lines[2].endswith("', NULL''); DROP SCHEMA Public --'),") assert lines[3].endswith("'イロハニホヘト チリヌルヲ ''ワカヨタレソ ツネナラム'),") assert lines[4].endswith("'ऄअआइ''ईउऊऋऌऍऎए'),") diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py index ef3ef21040..932b6876a2 100644 --- a/tests/common/test_logging.py +++ b/tests/common/test_logging.py @@ -2,33 +2,35 @@ import logging import json_logging from os import environ +from importlib.metadata import version as pkg_version -from dlt import __version__ as auto_version +from dlt import __version__ as code_version from dlt.common import logger, sleep from dlt.common.typing import StrStr -from dlt.common.configuration import RunConfiguration +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import RunConfiguration from 
tests.utils import preserve_environ +@configspec class PureBasicConfiguration(RunConfiguration): - PIPELINE_NAME: str = "logger" + pipeline_name: str = "logger" -class PureBasicConfigurationProc(PureBasicConfiguration): - _VERSION: str = "1.6.6" - - -class JsonLoggerConfiguration(PureBasicConfigurationProc): - LOG_FORMAT: str = "JSON" +@configspec +class JsonLoggerConfiguration(PureBasicConfiguration): + log_format: str = "JSON" +@configspec class SentryLoggerConfiguration(JsonLoggerConfiguration): - SENTRY_DSN: str = "http://user:pass@localhost/818782" + sentry_dsn: str = "http://user:pass@localhost/818782" +@configspec(init=True) class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): - LOG_LEVEL: str = "CRITICAL" + log_level: str = "CRITICAL" @pytest.fixture(scope="function") @@ -39,15 +41,14 @@ def environment() -> StrStr: def test_version_extract(environment: StrStr) -> None: - version = logger._extract_version_info(PureBasicConfiguration) - # if component ver not avail use system version - assert version == {'version': auto_version, 'component_name': 'logger'} - version = logger._extract_version_info(PureBasicConfigurationProc) - assert version["component_version"] == PureBasicConfigurationProc._VERSION + version = logger._extract_version_info(PureBasicConfiguration()) + assert version["dlt_version"].startswith(code_version) + lib_version = pkg_version("python-dlt") + assert version == {'dlt_version': lib_version, 'pipeline_name': 'logger'} # mock image info available in container _mock_image_env(environment) - version = logger._extract_version_info(PureBasicConfigurationProc) - assert version == {'version': auto_version, 'commit_sha': '192891', 'component_name': 'logger', 'component_version': '1.6.6', 'image_version': 'scale/v:112'} + version = logger._extract_version_info(PureBasicConfiguration()) + assert version == {'dlt_version': lib_version, 'commit_sha': '192891', 'pipeline_name': 'logger', 'image_version': 'scale/v:112'} def test_pod_info_extract(environment: StrStr) -> None: @@ -62,7 +63,7 @@ def test_pod_info_extract(environment: StrStr) -> None: def test_text_logger_init(environment: StrStr) -> None: _mock_image_env(environment) _mock_pod_env(environment) - logger.init_logging_from_config(PureBasicConfigurationProc) + logger.init_logging_from_config(PureBasicConfiguration()) logger.health("HEALTH data", extra={"metrics": "props"}) logger.metrics("METRICS data", extra={"metrics": "props"}) logger.warning("Warning message here") @@ -89,17 +90,13 @@ def test_json_logger_init(environment: StrStr) -> None: def test_sentry_log_level() -> None: - SentryLoggerCriticalConfiguration.LOG_LEVEL = "CRITICAL" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="CRITICAL")) assert sll._handler.level == logging._nameToLevel["CRITICAL"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "ERROR" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="ERROR")) assert sll._handler.level == logging._nameToLevel["ERROR"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "WARNING" - sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="WARNING")) assert sll._handler.level == logging._nameToLevel["WARNING"] - SentryLoggerCriticalConfiguration.LOG_LEVEL = "INFO" - sll = 
logger._get_sentry_log_level(SentryLoggerCriticalConfiguration) + sll = logger._get_sentry_log_level(SentryLoggerCriticalConfiguration(log_level="INFO")) assert sll._handler.level == logging._nameToLevel["WARNING"] @@ -107,7 +104,7 @@ def test_sentry_log_level() -> None: def test_sentry_init(environment: StrStr) -> None: _mock_image_env(environment) _mock_pod_env(environment) - logger.init_logging_from_config(SentryLoggerConfiguration) + logger.init_logging_from_config(SentryLoggerConfiguration()) try: 1 / 0 except ZeroDivisionError: diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index c80461dd59..30da464a88 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -1,7 +1,7 @@ -from typing import List, Literal, Mapping, MutableMapping, MutableSequence, Sequence, TypedDict, Optional +from typing import List, Literal, Mapping, MutableMapping, MutableSequence, NewType, Sequence, TypeVar, TypedDict, Optional -from dlt.common.typing import extract_optional_type, is_dict_generic_type, is_list_generic_type, is_literal_type, is_optional_type, is_typeddict +from dlt.common.typing import extract_inner_type, extract_optional_type, is_dict_generic_type, is_list_generic_type, is_literal_type, is_newtype_type, is_optional_type, is_typeddict @@ -20,7 +20,7 @@ def test_is_typeddict() -> None: assert is_typeddict(Sequence[str]) is False -def test_is_list_type() -> None: +def test_is_list_generic_type() -> None: # yes - we need a generic type assert is_list_generic_type(list) is False assert is_list_generic_type(List[str]) is True @@ -28,13 +28,13 @@ def test_is_list_type() -> None: assert is_list_generic_type(MutableSequence[str]) is True -def test_is_dict_type() -> None: +def test_is_dict_generic_type() -> None: assert is_dict_generic_type(dict) is False assert is_dict_generic_type(Mapping[str, str]) is True assert is_dict_generic_type(MutableMapping[str, str]) is True -def test_literal() -> None: +def test_is_literal() -> None: assert is_literal_type(TTestLi) is True assert is_literal_type("a") is False assert is_literal_type(List[str]) is False @@ -46,3 +46,22 @@ def test_optional() -> None: assert is_optional_type(TTestTyDi) is False assert extract_optional_type(TOptionalLi) is TTestLi assert extract_optional_type(TOptionalTyDi) is TTestTyDi + + +def test_is_newtype() -> None: + assert is_newtype_type(NewType("NT1", str)) is True + assert is_newtype_type(TypeVar("TV1", bound=str)) is False + assert is_newtype_type(1) is False + + +def test_extract_inner_type() -> None: + assert extract_inner_type(1) == 1 + assert extract_inner_type(str) is str + assert extract_inner_type(NewType("NT1", str)) is str + assert extract_inner_type(NewType("NT2", NewType("NT3", int))) is int + assert extract_inner_type(Optional[NewType("NT3", bool)]) is bool # noqa + l_1 = Literal[1, 2, 3] + assert extract_inner_type(l_1) is int + nt_l_2 = NewType("NTL2", float) + l_2 = Literal[nt_l_2(1.238), nt_l_2(2.343)] + assert extract_inner_type(l_2) is float diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index 019752c200..4bc516615b 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -88,7 +88,7 @@ def test_validate_schema_cases() -> None: validate_dict(TStoredSchema, schema_dict, ".", lambda k: not k.startswith("x-"), simple_regex_validator) - # with open("tests/common/cases/schemas/rasa/event_schema.json") as f: + # with open("tests/common/cases/schemas/rasa/event.schema.json") as f: # schema_dict: TStoredSchema = 
json.load(f) # validate_dict(TStoredSchema, schema_dict, ".", lambda k: not k.startswith("x-")) diff --git a/tests/common/utils.py b/tests/common/utils.py index e85b4649f3..ef423214ae 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -2,6 +2,9 @@ from typing import Mapping, cast from dlt.common import json +from dlt.common.typing import StrAny +from dlt.common.schema import utils +from dlt.common.schema.typing import TTableSchemaColumns def load_json_case(name: str) -> Mapping: @@ -20,3 +23,11 @@ def json_case_path(name: str) -> str: def yml_case_path(name: str) -> str: return f"./tests/common/cases/{name}.yml" + + +def row_to_column_schemas(row: StrAny) -> TTableSchemaColumns: + return {k: utils.add_missing_hints({ + "name": k, + "data_type": "text", + "nullable": False + }) for k in row.keys()} diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..a7f696b7d2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,29 @@ +import os +import dataclasses + +def pytest_configure(config): + # patch the configurations to use test storage by default, we modify the types (classes) fields + # the dataclass implementation will use those patched values when creating instances (the values present + # in the declaration are not frozen allowing patching) + + from dlt.common.configuration.specs import normalize_volume_configuration, run_configuration, load_volume_configuration, schema_volume_configuration + + test_storage_root = "_storage" + run_configuration.RunConfiguration.config_files_storage_path = os.path.join(test_storage_root, "config/%s") + + load_volume_configuration.LoadVolumeConfiguration.load_volume_path = os.path.join(test_storage_root, "load") + delattr(load_volume_configuration.LoadVolumeConfiguration, "__init__") + load_volume_configuration.LoadVolumeConfiguration = dataclasses.dataclass(load_volume_configuration.LoadVolumeConfiguration, init=True, repr=False) + + normalize_volume_configuration.NormalizeVolumeConfiguration.normalize_volume_path = os.path.join(test_storage_root, "normalize") + # delete __init__, otherwise it will not be recreated by dataclass + delattr(normalize_volume_configuration.NormalizeVolumeConfiguration, "__init__") + normalize_volume_configuration.NormalizeVolumeConfiguration = dataclasses.dataclass(normalize_volume_configuration.NormalizeVolumeConfiguration, init=True, repr=False) + + schema_volume_configuration.SchemaVolumeConfiguration.schema_volume_path = os.path.join(test_storage_root, "schemas") + delattr(schema_volume_configuration.SchemaVolumeConfiguration, "__init__") + schema_volume_configuration.SchemaVolumeConfiguration = dataclasses.dataclass(schema_volume_configuration.SchemaVolumeConfiguration, init=True, repr=False) + + + assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join(test_storage_root, "config/%s") + assert run_configuration.RunConfiguration().config_files_storage_path == os.path.join(test_storage_root, "config/%s") diff --git a/tests/dbt_runner/test_runner_bigquery.py b/tests/dbt_runner/test_runner_bigquery.py index 24756c3575..c11e9f992e 100644 --- a/tests/dbt_runner/test_runner_bigquery.py +++ b/tests/dbt_runner/test_runner_bigquery.py @@ -2,14 +2,14 @@ import pytest from dlt.common import logger -from dlt.common.configuration import GcpClientCredentials +from dlt.common.configuration.specs import GcpClientCredentials from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr -from dlt.common.utils import 
uniq_id, with_custom_environ +from dlt.common.utils import uniq_id from dlt.dbt_runner.utils import DBTProcessingError from dlt.dbt_runner import runner -from dlt.load.bigquery.client import BigQuerySqlClient +from dlt.load.bigquery.bigquery import BigQuerySqlClient from tests.utils import add_config_to_env, init_logger, preserve_environ from tests.dbt_runner.utils import setup_runner @@ -57,8 +57,8 @@ def test_create_folders() -> None: setup_runner("eks_dev_dest", override_values={ "SOURCE_SCHEMA_PREFIX": "carbon_bot_3", "PACKAGE_ADDITIONAL_VARS": {"add_var_name": "add_var_value"}, - "LOG_FORMAT": "JSON", - "LOG_LEVEL": "INFO" + "log_format": "JSON", + "log_level": "INFO" }) assert runner.repo_path.endswith(runner.CLONED_PACKAGE_NAME) diff --git a/tests/dbt_runner/test_runner_redshift.py b/tests/dbt_runner/test_runner_redshift.py index a7cc1ac8b6..0c915ea709 100644 --- a/tests/dbt_runner/test_runner_redshift.py +++ b/tests/dbt_runner/test_runner_redshift.py @@ -4,9 +4,9 @@ from prometheus_client import CollectorRegistry from dlt.common import logger -from dlt.common.configuration import PostgresCredentials -from dlt.common.configuration.utils import make_configuration -from dlt.common.file_storage import FileStorage +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.specs import PostgresCredentials +from dlt.common.storages import FileStorage from dlt.common.telemetry import TRunMetrics, get_metrics_from_prometheus from dlt.common.typing import StrStr from dlt.common.utils import uniq_id, with_custom_environ @@ -14,9 +14,9 @@ from dlt.dbt_runner.utils import DBTProcessingError from dlt.dbt_runner.configuration import DBTRunnerConfiguration from dlt.dbt_runner import runner -from dlt.load.redshift.client import RedshiftSqlClient +from dlt.load.redshift.redshift import RedshiftSqlClient -from tests.utils import add_config_to_env, clean_storage, init_logger, preserve_environ +from tests.utils import add_config_to_env, clean_test_storage, init_logger, preserve_environ from tests.dbt_runner.utils import modify_and_commit_file, load_secret, setup_runner DEST_SCHEMA_PREFIX = "test_" + uniq_id() @@ -25,7 +25,7 @@ @pytest.fixture(scope="module", autouse=True) def module_autouse() -> None: # disable GCP in environ - del environ["GCP__PROJECT_ID"] + del environ["CREDENTIALS__PROJECT_ID"] # set the test case for the unit tests environ["DEFAULT_DATASET"] = "test_fixture_carbon_bot_session_cases" add_config_to_env(PostgresCredentials) @@ -61,27 +61,25 @@ def module_autouse() -> None: def test_configuration() -> None: # check names normalized - C = make_configuration( - DBTRunnerConfiguration, - DBTRunnerConfiguration, - initial_values={"PACKAGE_REPOSITORY_SSH_KEY": "---NO NEWLINE---", "SOURCE_SCHEMA_PREFIX": "schema"} + C = resolve_configuration( + DBTRunnerConfiguration(), + initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---NO NEWLINE---", "SOURCE_SCHEMA_PREFIX": "schema"} ) - assert C.PACKAGE_REPOSITORY_SSH_KEY == "---NO NEWLINE---\n" + assert C.package_repository_ssh_key == "---NO NEWLINE---\n" - C = make_configuration( - DBTRunnerConfiguration, - DBTRunnerConfiguration, - initial_values={"PACKAGE_REPOSITORY_SSH_KEY": "---WITH NEWLINE---\n", "SOURCE_SCHEMA_PREFIX": "schema"} + C = resolve_configuration( + DBTRunnerConfiguration(), + initial_value={"PACKAGE_REPOSITORY_SSH_KEY": "---WITH NEWLINE---\n", "SOURCE_SCHEMA_PREFIX": "schema"} ) - assert C.PACKAGE_REPOSITORY_SSH_KEY == "---WITH NEWLINE---\n" + assert C.package_repository_ssh_key == "---WITH NEWLINE---\n" 
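# Sketch of the class-patching trick used by pytest_configure in tests/conftest.py
# (added earlier in this diff), reduced to a standalone example with a made-up
# VolumeConfiguration class: change the field default on the type itself, drop the
# generated __init__, and re-apply dataclass() so newly created instances pick up the
# patched default.
import dataclasses

@dataclasses.dataclass
class VolumeConfiguration:          # hypothetical stand-in for the real spec classes
    volume_path: str = "/var/dlt/load"

VolumeConfiguration.volume_path = "_storage/load"   # patch the class-level default
delattr(VolumeConfiguration, "__init__")            # let dataclass() regenerate __init__
VolumeConfiguration = dataclasses.dataclass(VolumeConfiguration, init=True, repr=False)

assert VolumeConfiguration().volume_path == "_storage/load"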
def test_create_folders() -> None: setup_runner("eks_dev_dest", override_values={ "SOURCE_SCHEMA_PREFIX": "carbon_bot_3", "PACKAGE_ADDITIONAL_VARS": {"add_var_name": "add_var_value"}, - "LOG_FORMAT": "JSON", - "LOG_LEVEL": "INFO" + "log_format": "JSON", + "log_level": "INFO" }) assert runner.repo_path.endswith(runner.CLONED_PACKAGE_NAME) assert runner.profile_name == "rasa_semantic_schema_redshift" @@ -94,7 +92,7 @@ def test_initialize_package_wrong_key() -> None: # private repo "PACKAGE_REPOSITORY_URL": "git@github.com:scale-vector/rasa_bot_experiments.git" }) - runner.CONFIG.PACKAGE_REPOSITORY_SSH_KEY = load_secret("DEPLOY_KEY") + runner.CONFIG.package_repository_ssh_key = load_secret("DEPLOY_KEY") with pytest.raises(GitCommandError): runner.run(None) @@ -104,12 +102,12 @@ def test_reinitialize_package() -> None: setup_runner(DEST_SCHEMA_PREFIX) runner.ensure_newest_package() # mod the package - readme_path = modify_and_commit_file(runner.repo_path, "README.md", content=runner.CONFIG.DEST_SCHEMA_PREFIX) + readme_path = modify_and_commit_file(runner.repo_path, "README.md", content=runner.CONFIG.dest_schema_prefix) assert runner.storage.has_file(readme_path) # this will wipe out old package and clone again runner.ensure_newest_package() # we have old file back - assert runner.storage.load(f"{runner.CLONED_PACKAGE_NAME}/README.md") != runner.CONFIG.DEST_SCHEMA_PREFIX + assert runner.storage.load(f"{runner.CLONED_PACKAGE_NAME}/README.md") != runner.CONFIG.dest_schema_prefix def test_dbt_test_no_raw_schema() -> None: @@ -178,7 +176,7 @@ def test_dbt_incremental_schema_out_of_sync_error() -> None: def get_runner() -> FileStorage: - clean_storage() + clean_test_storage() runner.storage, runner.dbt_package_vars, runner.global_args, runner.repo_path, runner.profile_name = runner.create_folders() runner.model_elapsed_gauge, runner.model_exec_info = runner.create_gauges(CollectorRegistry(auto_describe=True)) return runner.storage diff --git a/tests/dbt_runner/test_utils.py b/tests/dbt_runner/test_utils.py index 1e849246d6..162d5fd20a 100644 --- a/tests/dbt_runner/test_utils.py +++ b/tests/dbt_runner/test_utils.py @@ -3,11 +3,11 @@ from git import GitCommandError, Repo, RepositoryDirtyError import pytest -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.dbt_runner.utils import DBTProcessingError, clone_repo, ensure_remote_head, git_custom_key_command, initialize_dbt_logging, run_dbt_command -from tests.utils import root_storage +from tests.utils import test_storage from tests.dbt_runner.utils import load_secret, modify_and_commit_file, restore_secret_storage_path @@ -32,60 +32,60 @@ def test_no_ssh_key_context() -> None: assert git_command == 'ssh -o "StrictHostKeyChecking accept-new"' -def test_clone(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo clone_repo(AWESOME_REPO, repo_path, with_git_command=None) - assert root_storage.has_folder("awesome_repo") + assert test_storage.has_folder("awesome_repo") # make sure directory clean ensure_remote_head(repo_path, with_git_command=None) -def test_clone_with_commit_id(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone_with_commit_id(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo 
clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="7f88000be2d4f265c83465fec4b0b3613af347dd") - assert root_storage.has_folder("awesome_repo") + assert test_storage.has_folder("awesome_repo") ensure_remote_head(repo_path, with_git_command=None) -def test_clone_with_wrong_branch(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("awesome_repo") +def test_clone_with_wrong_branch(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo with pytest.raises(GitCommandError): clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="wrong_branch") -def test_clone_with_deploy_key_access_denied(root_storage: FileStorage) -> None: +def test_clone_with_deploy_key_access_denied(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo") + repo_path = test_storage.make_full_path("private_repo") with git_custom_key_command(secret) as git_command: with pytest.raises(GitCommandError): clone_repo(PRIVATE_REPO, repo_path, with_git_command=git_command) -def test_clone_with_deploy_key(root_storage: FileStorage) -> None: +def test_clone_with_deploy_key(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo_access") + repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command) ensure_remote_head(repo_path, with_git_command=git_command) -def test_repo_status_update(root_storage: FileStorage) -> None: +def test_repo_status_update(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") - repo_path = root_storage._make_path("private_repo_access") + repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command) # modify README.md readme_path = modify_and_commit_file(repo_path, "README.md") - assert root_storage.has_file(readme_path) + assert test_storage.has_file(readme_path) with pytest.raises(RepositoryDirtyError): ensure_remote_head(repo_path, with_git_command=git_command) -def test_dbt_commands(root_storage: FileStorage) -> None: - repo_path = root_storage._make_path("jaffle_shop") +def test_dbt_commands(test_storage: FileStorage) -> None: + repo_path = test_storage.make_full_path("jaffle_shop") # clone jaffle shop for dbt 1.0.0 clone_repo(JAFFLE_SHOP_REPO, repo_path, with_git_command=None, branch="core-v1.0.0") # copy profile diff --git a/tests/dbt_runner/utils.py b/tests/dbt_runner/utils.py index 2484fbdb9b..baf762ed28 100644 --- a/tests/dbt_runner/utils.py +++ b/tests/dbt_runner/utils.py @@ -8,7 +8,7 @@ from dlt.dbt_runner.configuration import gen_configuration_variant from dlt.dbt_runner import runner -from tests.utils import clean_storage +from tests.utils import clean_test_storage SECRET_STORAGE_PATH = environ.SECRET_STORAGE_PATH @@ -21,7 +21,7 @@ def restore_secret_storage_path() -> None: def load_secret(name: str) -> str: environ.SECRET_STORAGE_PATH = "./tests/dbt_runner/secrets/%s" - secret = environ._get_key_value(name, environ.TSecretValue) + secret = environ.get_key(name, environ.TSecretValue) if not secret: raise FileNotFoundError(environ.SECRET_STORAGE_PATH % name) return secret @@ -45,11 +45,11 @@ def modify_and_commit_file(repo_path: str, file_name: str, content: str = "NEW R def 
setup_runner(dest_schema_prefix: str, override_values: StrAny = None) -> None: - clean_storage() + clean_test_storage() C = gen_configuration_variant(initial_values=override_values) # set unique dest schema prefix by default - C.DEST_SCHEMA_PREFIX = dest_schema_prefix - C.PACKAGE_RUN_PARAMS = ["--fail-fast", "--full-refresh"] + C.dest_schema_prefix = dest_schema_prefix + C.package_run_params = ["--fail-fast", "--full-refresh"] # override values including the defaults above if override_values: for k,v in override_values.items(): diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index e22ea64107..4e25331bce 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -4,16 +4,16 @@ from dlt.common import json, pendulum, Decimal from dlt.common.arithmetics import numeric_default_context -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id from dlt.load.exceptions import LoadJobNotExistsException, LoadJobServerTerminalException from dlt.load import Load -from dlt.load.bigquery.client import BigQueryClient +from dlt.load.bigquery.bigquery import BigQueryClient -from tests.utils import TEST_STORAGE, delete_storage -from tests.load.utils import cm_yield_client_with_storage, expect_load_file, prepare_table, yield_client_with_storage +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage +from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage, cm_yield_client_with_storage @pytest.fixture(scope="module") @@ -23,12 +23,12 @@ def client() -> Iterator[BigQueryClient]: @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: @@ -61,13 +61,13 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. it should fallback to retrieve job silently - r_job = client.start_file_load(client.schema.get_table(user_table_name), file_storage._make_path(job.file_name())) + r_job = client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name())) assert r_job.status() == "completed" @pytest.mark.parametrize('location', ["US", "EU"]) def test_bigquery_location(location: str, file_storage: FileStorage) -> None: - with cm_yield_client_with_storage("bigquery", initial_values={"LOCATION": location}) as client: + with cm_yield_client_with_storage("bigquery", initial_values={"location": location}) as client: user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), @@ -78,7 +78,7 @@ def test_bigquery_location(location: str, file_storage: FileStorage) -> None: job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. 
it should fallback to retrieve job silently - client.start_file_load(client.schema.get_table(user_table_name), file_storage._make_path(job.file_name())) + client.start_file_load(client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name())) canonical_name = client.sql_client.make_qualified_table_name(user_table_name) t = client.sql_client.native_connection.get_table(canonical_name) assert t.location == location diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 5d3fdacf57..a8ffe2b02b 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -4,9 +4,11 @@ from dlt.common.utils import custom_environ, uniq_id from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import make_configuration, GcpClientCredentials +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.specs import GcpClientCredentials -from dlt.load.bigquery.client import BigQueryClient +from dlt.load.bigquery.bigquery import BigQueryClient +from dlt.load.bigquery.configuration import BigQueryClientConfiguration from dlt.load.exceptions import LoadClientSchemaWillNotUpdate from tests.load.utils import TABLE_UPDATE @@ -19,20 +21,19 @@ def schema() -> Schema: def test_configuration() -> None: # check names normalized - with custom_environ({"GCP__PRIVATE_KEY": "---NO NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials, GcpClientCredentials) - assert C.PRIVATE_KEY == "---NO NEWLINE---\n" + with custom_environ({"CREDENTIALS__PRIVATE_KEY": "---NO NEWLINE---\n"}): + C = resolve_configuration(GcpClientCredentials()) + assert C.private_key == "---NO NEWLINE---\n" - with custom_environ({"GCP__PRIVATE_KEY": "---WITH NEWLINE---\n"}): - C = make_configuration(GcpClientCredentials, GcpClientCredentials) - assert C.PRIVATE_KEY == "---WITH NEWLINE---\n" + with custom_environ({"CREDENTIALS__PRIVATE_KEY": "---WITH NEWLINE---\n"}): + C = resolve_configuration(GcpClientCredentials()) + assert C.private_key == "---WITH NEWLINE---\n" @pytest.fixture def gcp_client(schema: Schema) -> BigQueryClient: # return client without opening connection - BigQueryClient.configure(initial_values={"DEFAULT_DATASET": uniq_id()}) - return BigQueryClient(schema) + return BigQueryClient(schema, BigQueryClientConfiguration(dataset_name="TEST" + uniq_id(), credentials=GcpClientCredentials())) def test_create_table(gcp_client: BigQueryClient) -> None: diff --git a/tests/load/cases/event_schema.json b/tests/load/cases/event.schema.json similarity index 100% rename from tests/load/cases/event_schema.json rename to tests/load/cases/event.schema.json diff --git a/tests/load/redshift/test_pipelines.py b/tests/load/redshift/test_pipelines.py index 77d16c4932..30e2936708 100644 --- a/tests/load/redshift/test_pipelines.py +++ b/tests/load/redshift/test_pipelines.py @@ -1,66 +1,66 @@ -import os -import pytest -from os import environ +# import os +# import pytest +# from os import environ -from dlt.common.schema.schema import Schema -from dlt.common.utils import uniq_id -from dlt.pipeline import Pipeline, PostgresPipelineCredentials -from dlt.pipeline.exceptions import InvalidPipelineContextException +# from dlt.common.schema.schema import Schema +# from dlt.common.utils import uniq_id +# from dlt.pipeline import Pipeline, PostgresPipelineCredentials +# from dlt.pipeline.exceptions import 
InvalidPipelineContextException -from tests.utils import autouse_root_storage, TEST_STORAGE +# from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT -FAKE_CREDENTIALS = PostgresPipelineCredentials("redshift", None, None, None, None) +# FAKE_CREDENTIALS = PostgresPipelineCredentials("redshift", None, None, None, None) -def test_empty_default_schema_name() -> None: - p = Pipeline("test_empty_default_schema_name") - FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_empty_default_schema_name" + uniq_id() - p.create_pipeline(FAKE_CREDENTIALS, os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), Schema("default")) - p.extract(iter(["a", "b", "c"]), table_name="test") - p.normalize() - p.load() +# def test_empty_default_schema_name() -> None: +# p = Pipeline("test_empty_default_schema_name") +# FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_empty_default_schema_name" + uniq_id() +# p.create_pipeline(FAKE_CREDENTIALS, os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), Schema("default")) +# p.extract(iter(["a", "b", "c"]), table_name="test") +# p.normalize() +# p.load() - # delete data - with p.sql_client() as c: - c.drop_dataset() +# # delete data +# with p.sql_client() as c: +# c.drop_dataset() - # try to restore pipeline - r_p = Pipeline("test_empty_default_schema_name") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - schema = r_p.get_default_schema() - assert schema.name == "default" +# # try to restore pipeline +# r_p = Pipeline("test_empty_default_schema_name") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# schema = r_p.get_default_schema() +# assert schema.name == "default" -def test_create_wipes_working_dir() -> None: - p = Pipeline("test_create_wipes_working_dir") - FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_create_wipes_working_dir" + uniq_id() - p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("table")) - p.extract(iter(["a", "b", "c"]), table_name="test") - p.normalize() - assert len(p.list_normalized_loads()) > 0 +# def test_create_wipes_working_dir() -> None: +# p = Pipeline("test_create_wipes_working_dir") +# FAKE_CREDENTIALS.DEFAULT_DATASET = environ["DEFAULT_DATASET"] = "test_create_wipes_working_dir" + uniq_id() +# p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("table")) +# p.extract(iter(["a", "b", "c"]), table_name="test") +# p.normalize() +# assert len(p.list_normalized_loads()) > 0 - # try to restore pipeline - r_p = Pipeline("test_create_wipes_working_dir") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - assert len(r_p.list_normalized_loads()) > 0 - schema = r_p.get_default_schema() - assert schema.name == "table" +# # try to restore pipeline +# r_p = Pipeline("test_create_wipes_working_dir") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# assert len(r_p.list_normalized_loads()) > 0 +# schema = r_p.get_default_schema() +# assert schema.name == "table" - # create pipeline in the same dir - p = Pipeline("overwrite_old") - # FAKE_CREDENTIALS.DEFAULT_DATASET = "new" - p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("matrix")) - assert len(p.list_normalized_loads()) == 0 +# # create pipeline in the same dir +# p = Pipeline("overwrite_old") +# # FAKE_CREDENTIALS.DEFAULT_DATASET = "new" +# 
p.create_pipeline(FAKE_CREDENTIALS, working_dir=os.path.join(TEST_STORAGE_ROOT, FAKE_CREDENTIALS.DEFAULT_DATASET), schema=Schema("matrix")) +# assert len(p.list_normalized_loads()) == 0 - # old pipeline is still functional but storage is wiped out - # TODO: but should be inactive - coming in API v2 - # with pytest.raises(InvalidPipelineContextException): - assert len(r_p.list_normalized_loads()) == 0 +# # old pipeline is still functional but storage is wiped out +# # TODO: but should be inactive - coming in API v2 +# # with pytest.raises(InvalidPipelineContextException): +# assert len(r_p.list_normalized_loads()) == 0 - # so recreate it - r_p = Pipeline("overwrite_old") - r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) - assert len(r_p.list_normalized_loads()) == 0 - schema = r_p.get_default_schema() - assert schema.name == "matrix" +# # so recreate it +# r_p = Pipeline("overwrite_old") +# r_p.restore_pipeline(FAKE_CREDENTIALS, p.working_dir) +# assert len(r_p.list_normalized_loads()) == 0 +# schema = r_p.get_default_schema() +# assert schema.name == "matrix" diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index c799bc2669..1416f63201 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -1,28 +1,29 @@ from typing import Iterator import pytest +from unittest.mock import patch from dlt.common import pendulum, Decimal from dlt.common.arithmetics import numeric_default_context -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id from dlt.load.exceptions import LoadClientTerminalInnerException from dlt.load import Load -from dlt.load.redshift.client import RedshiftClient, RedshiftInsertLoadJob, psycopg2 +from dlt.load.redshift.redshift import RedshiftClient, RedshiftInsertLoadJob, psycopg2 -from tests.utils import TEST_STORAGE, delete_storage, skipifpypy +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() @pytest.fixture(scope="module") @@ -94,13 +95,15 @@ def test_long_names(client: RedshiftClient) -> None: @skipifpypy def test_loading_errors(client: RedshiftClient, file_storage: FileStorage) -> None: + caps = client.capabilities() + user_table_name = prepare_table(client) # insert string longer than redshift maximum insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" # try some unicode value - redshift checks the max length based on utf-8 representation, not the number of characters # max_len_str = 'उ' * (65535 // 3) + 1 -> does not fit # max_len_str = 'a' * 65535 + 1 -> does not fit - max_len_str = 'उ' * ((65535 // 3) + 1) + max_len_str = 'उ' * ((caps["max_text_data_type_length"] // 3) + 1) # max_len_str_b = max_len_str.encode("utf-8") # print(len(max_len_str_b)) row_id = uniq_id() @@ -157,10 +160,13 @@ def test_loading_errors(client: RedshiftClient, file_storage: FileStorage) -> No def test_query_split(client: RedshiftClient, file_storage: FileStorage) -> None: - max_statement_size = RedshiftInsertLoadJob.MAX_STATEMENT_SIZE 
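# Sketch of the capability-mocking pattern the rewritten test_query_split /
# test_maximum_query_size in this hunk rely on: instead of mutating a module constant
# such as MAX_STATEMENT_SIZE, a modified capabilities dict is returned from a patched
# classmethod for the duration of the test. The helper name and the limit value here
# are made up for illustration.
from unittest.mock import patch

def with_tiny_query_limit(client_cls, run_test) -> None:
    mocked_caps = client_cls.capabilities()
    mocked_caps["max_query_length"] = 2     # forces INSERT statements to be split line by line
    with patch.object(client_cls, "capabilities", return_value=mocked_caps):
        run_test()                          # any capabilities() call now sees the mocked limits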
- try: - # this guarantees that we execute inserts line by line - RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = 1 + mocked_caps = RedshiftClient.capabilities() + # this guarantees that we execute inserts line by line + mocked_caps["max_query_length"] = 2 + + with patch.object(RedshiftClient, "capabilities") as caps: + caps.return_value = mocked_caps + print(RedshiftClient.capabilities()) user_table_name = prepare_table(client) insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}')" @@ -182,19 +188,23 @@ def test_query_split(client: RedshiftClient, file_storage: FileStorage) -> None: assert ids == v_ids - finally: - RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = max_statement_size - @pytest.mark.skip -def test_maximum_statement(client: RedshiftClient, file_storage: FileStorage) -> None: - assert RedshiftInsertLoadJob.MAX_STATEMENT_SIZE == 20 * 1024 * 1024, "to enable this test, you must increase RedshiftInsertLoadJob.MAX_STATEMENT_SIZE = 20 * 1024 * 1024" - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" - insert_sql = insert_sql + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 - insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") +@skipifpypy +def test_maximum_query_size(client: RedshiftClient, file_storage: FileStorage) -> None: + mocked_caps = RedshiftClient.capabilities() + # this guarantees that we cross the redshift query limit + mocked_caps["max_query_length"] = 2 * 20 * 1024 * 1024 - user_table_name = prepare_table(client) - with pytest.raises(LoadClientTerminalInnerException) as exv: - expect_load_file(client, file_storage, insert_sql, user_table_name) - # psycopg2.errors.SyntaxError: Statement is too large. Statement Size: 20971754 bytes. Maximum Allowed: 16777216 bytes - assert type(exv.value.inner_exc) is psycopg2.ProgrammingError \ No newline at end of file + with patch.object(RedshiftClient, "capabilities") as caps: + caps.return_value = mocked_caps + + insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" + insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" + insert_sql = insert_sql + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 + insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") + + user_table_name = prepare_table(client) + with pytest.raises(LoadClientTerminalInnerException) as exv: + expect_load_file(client, file_storage, insert_sql, user_table_name) + # psycopg2.errors.SyntaxError: Statement is too large. Statement Size: 20971754 bytes. 
Maximum Allowed: 16777216 bytes + assert type(exv.value.inner_exc) is psycopg2.errors.SyntaxError \ No newline at end of file diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index edcf00873b..d89ed5aa71 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -4,10 +4,12 @@ from dlt.common.utils import uniq_id, custom_environ from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.configuration import PostgresCredentials, make_configuration +from dlt.common.configuration import resolve_configuration +from dlt.common.configuration.specs import PostgresCredentials from dlt.load.exceptions import LoadClientSchemaWillNotUpdate -from dlt.load.redshift.client import RedshiftClient +from dlt.load.redshift.redshift import RedshiftClient +from dlt.load.redshift.configuration import RedshiftClientConfiguration from tests.load.utils import TABLE_UPDATE @@ -20,16 +22,15 @@ def schema() -> Schema: @pytest.fixture def client(schema: Schema) -> RedshiftClient: # return client without opening connection - RedshiftClient.configure(initial_values={"DEFAULT_DATASET": "TEST" + uniq_id()}) - return RedshiftClient(schema) + return RedshiftClient(schema, RedshiftClientConfiguration(dataset_name="TEST" + uniq_id())) def test_configuration() -> None: # check names normalized - with custom_environ({"PG__DBNAME": "UPPER_CASE_DATABASE", "PG__PASSWORD": " pass\n"}): - C = make_configuration(PostgresCredentials, PostgresCredentials) - assert C.DBNAME == "upper_case_database" - assert C.PASSWORD == "pass" + with custom_environ({"CREDENTIALS__DBNAME": "UPPER_CASE_DATABASE", "CREDENTIALS__PASSWORD": " pass\n"}): + C = resolve_configuration(PostgresCredentials()) + assert C.dbname == "upper_case_database" + assert C.password == "pass" def test_create_table(client: RedshiftClient) -> None: diff --git a/tests/load/test_client.py b/tests/load/test_client.py index 483b348749..b90c9b6f64 100644 --- a/tests/load/test_client.py +++ b/tests/load/test_client.py @@ -6,15 +6,16 @@ from dlt.common import json, pendulum from dlt.common.schema import Schema from dlt.common.schema.utils import new_table -from dlt.common.file_storage import FileStorage +from dlt.common.storages import FileStorage from dlt.common.schema import TTableSchemaColumns from dlt.common.utils import uniq_id -from dlt.load.client_base import DBCursor, SqlJobClientBase +from dlt.load.sql_client import DBCursor +from dlt.load.job_client_impl import SqlJobClientBase -from tests.utils import TEST_STORAGE, delete_storage +from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import load_json_case -from tests.load.utils import TABLE_UPDATE, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table +from tests.load.utils import TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table ALL_CLIENTS = ['redshift_client', 'bigquery_client'] @@ -23,12 +24,12 @@ @pytest.fixture def file_storage() -> FileStorage: - return FileStorage(TEST_STORAGE, file_type="b", makedirs=True) + return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) @pytest.fixture(autouse=True) def auto_delete_storage() -> None: - delete_storage() + delete_test_storage() @pytest.fixture(scope="module") @@ -211,7 +212,7 @@ def 
diff --git a/tests/load/test_client.py b/tests/load/test_client.py
index 483b348749..b90c9b6f64 100644
--- a/tests/load/test_client.py
+++ b/tests/load/test_client.py
@@ -6,15 +6,16 @@
 from dlt.common import json, pendulum
 from dlt.common.schema import Schema
 from dlt.common.schema.utils import new_table
-from dlt.common.file_storage import FileStorage
+from dlt.common.storages import FileStorage
 from dlt.common.schema import TTableSchemaColumns
 from dlt.common.utils import uniq_id
 
-from dlt.load.client_base import DBCursor, SqlJobClientBase
+from dlt.load.sql_client import DBCursor
+from dlt.load.job_client_impl import SqlJobClientBase
 
-from tests.utils import TEST_STORAGE, delete_storage
+from tests.utils import TEST_STORAGE_ROOT, delete_test_storage
 from tests.common.utils import load_json_case
-from tests.load.utils import TABLE_UPDATE, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table
+from tests.load.utils import TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW, expect_load_file, yield_client_with_storage, cm_yield_client_with_storage, write_dataset, prepare_table
 
 ALL_CLIENTS = ['redshift_client', 'bigquery_client']
 
@@ -23,12 +24,12 @@
 
 @pytest.fixture
 def file_storage() -> FileStorage:
-    return FileStorage(TEST_STORAGE, file_type="b", makedirs=True)
+    return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True)
 
 
 @pytest.fixture(autouse=True)
 def auto_delete_storage() -> None:
-    delete_storage()
+    delete_test_storage()
 
 
 @pytest.fixture(scope="module")
@@ -211,7 +212,7 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) -
     canonical_name = client.sql_client.make_qualified_table_name(table_name)
     # write only first row
     with io.StringIO() as f:
-        write_dataset(client, f, [rows[0]], rows[0].keys())
+        write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"])
         query = f.getvalue()
     expect_load_file(client, file_storage, query, table_name)
     db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]
@@ -219,7 +220,7 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) -
     assert list(db_row) == list(rows[0].values())
     # write second row that contains two nulls
     with io.StringIO() as f:
-        write_dataset(client, f, [rows[1]], rows[0].keys())
+        write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"])
         query = f.getvalue()
     expect_load_file(client, file_storage, query, table_name)
     db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name} WHERE f_int = {rows[1]['f_int']}")[0]
@@ -236,7 +237,7 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS
     inj_str = f", NULL'); DROP TABLE {canonical_name} --"
     row["f_str"] = inj_str
     with io.StringIO() as f:
-        write_dataset(client, f, [rows[0]], rows[0].keys())
+        write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"])
         query = f.getvalue()
     expect_load_file(client, file_storage, query, table_name)
     db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]
@@ -248,7 +249,7 @@ def test_data_writer_string_escape_edge(client: SqlJobClientBase, file_storage:
     rows, table_name = prepare_schema(client, "weird_rows")
     canonical_name = client.sql_client.make_qualified_table_name(table_name)
     with io.StringIO() as f:
-        write_dataset(client, f, rows, rows[0].keys())
+        write_dataset(client, f, rows, client.schema.get_table(table_name)["columns"])
         query = f.getvalue()
     expect_load_file(client, file_storage, query, table_name)
     for i in range(1,len(rows) + 1):
@@ -267,7 +268,7 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: str, f
     canonical_name = client.sql_client.make_qualified_table_name(table_name)
     # write row
     with io.StringIO() as f:
-        write_dataset(client, f, [TABLE_ROW], TABLE_ROW.keys())
+        write_dataset(client, f, [TABLE_ROW], TABLE_UPDATE_COLUMNS_SCHEMA)
        query = f.getvalue()
     expect_load_file(client, file_storage, query, table_name)
     db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0])
@@ -289,7 +290,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, fi
     child_table = client.schema.normalize_make_path(table_name, "child")
     # add child table without write disposition so it will be inferred from the parent
     client.schema.update_schema(
-        new_table(child_table, columns=TABLE_UPDATE, parent_name=table_name)
+        new_table(child_table, columns=TABLE_UPDATE, parent_table_name=table_name)
     )
     client.schema.bump_version()
     client.update_storage_schema()
@@ -299,7 +300,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: str, fi
             table_row = deepcopy(TABLE_ROW)
             table_row["col1"] = idx
             with io.StringIO() as f:
-                write_dataset(client, f, [table_row], TABLE_ROW.keys())
+                write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA)
                 query = f.getvalue()
             expect_load_file(client, file_storage, query, t)
             db_rows = list(client.sql_client.execute_sql(f"SELECT * FROM {t} ORDER BY col1 ASC"))
@@ -323,24 +324,24 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No
         "timestamp": str(pendulum.now())
     }
     with io.StringIO() as f:
-        write_dataset(client, f, [load_json], load_json.keys())
+        write_dataset(client, f, [load_json], client.schema.get_table(user_table_name)["columns"])
         dataset = f.getvalue()
     job = expect_load_file(client, file_storage, dataset, user_table_name)
     # now try to retrieve the job
     # TODO: we should re-create client instance as this call is intended to be run after some disruption ie. stopped loader process
-    r_job = client.restore_file_load(file_storage._make_path(job.file_name()))
+    r_job = client.restore_file_load(file_storage.make_full_path(job.file_name()))
     assert r_job.status() == "completed"
     # use just file name to restore
     r_job = client.restore_file_load(job.file_name())
     assert r_job.status() == "completed"
 
 
-@pytest.mark.parametrize('client_type', ALL_CLIENT_TYPES)
-def test_default_schema_name_init_storage(client_type: str) -> None:
-    with cm_yield_client_with_storage(client_type, initial_values={
-        "DEFAULT_SCHEMA_NAME": "event"  # pass the schema that is a default schema. that should create dataset with the name `DEFAULT_DATASET`
+@pytest.mark.parametrize('destination_name', ALL_CLIENT_TYPES)
+def test_default_schema_name_init_storage(destination_name: str) -> None:
+    with cm_yield_client_with_storage(destination_name, initial_values={
+        "default_schema_name": "event"  # pass the schema that is a default schema. that should create dataset with the name `dataset_name`
     }) as client:
-        assert client.sql_client.default_dataset_name == client.CONFIG.DEFAULT_DATASET
+        assert client.sql_client.default_dataset_name == client.config.dataset_name
 
 
 def prepare_schema(client: SqlJobClientBase, case: str) -> None:
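Throughout test_client.py, write_dataset now receives the table's columns schema instead of a bare list of row keys. The helper below only shows the shape of that argument, a mapping of column name to column definition built from a list, mirroring the TABLE_UPDATE_COLUMNS_SCHEMA comprehension added in tests/load/utils.py; the column fields are illustrative, not dlt's type definitions:

from typing import Any, Dict, List

# a list of column definitions, similar in spirit to TABLE_UPDATE
columns_list: List[Dict[str, Any]] = [
    {"name": "col1", "data_type": "bigint", "nullable": False},
    {"name": "col2", "data_type": "double", "nullable": False},
]

# the writer wants a name -> column mapping
columns_schema: Dict[str, Dict[str, Any]] = {c["name"]: c for c in columns_list}

# with the full schema available, a writer can emit a typed header
# (for example an INSERT column list) instead of relying on dict key order
header = ", ".join(columns_schema.keys())
assert header == "col1, col2"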
diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py
index 0330bb71bf..405bab254e 100644
--- a/tests/load/test_dummy_client.py
+++ b/tests/load/test_dummy_client.py
@@ -1,26 +1,26 @@
 import shutil
 import os
-from os import environ
 from multiprocessing.pool import ThreadPool
 from typing import List, Sequence, Tuple
 import pytest
 from unittest.mock import patch
 from prometheus_client import CollectorRegistry
 
-from dlt.common.file_storage import FileStorage
 from dlt.common.exceptions import TerminalException, TerminalValueError
 from dlt.common.schema import Schema
-from dlt.common.storages.load_storage import JobWithUnsupportedWriterException, LoadStorage
-from dlt.common.typing import StrAny
+from dlt.common.storages import FileStorage, LoadStorage
+from dlt.common.storages.load_storage import JobWithUnsupportedWriterException
 from dlt.common.utils import uniq_id
-from dlt.load.client_base import JobClientBase, LoadEmptyJob, LoadJob
+from dlt.common.destination import DestinationReference, LoadJob
 
-from dlt.load.configuration import configuration, ProductionLoaderConfiguration, LoaderConfiguration
-from dlt.load.dummy import client
-from dlt.load import Load, __version__
+from dlt.load import Load
+from dlt.load.job_client_impl import LoadEmptyJob
+
+from dlt.load import dummy
+from dlt.load.dummy import dummy as dummy_impl
 from dlt.load.dummy.configuration import DummyClientConfiguration
 
-from tests.utils import clean_storage, init_logger
+from tests.utils import clean_test_storage, init_logger, TEST_DICT_CONFIG_PROVIDER
 
 
 NORMALIZED_FILES = [
@@ -31,7 +31,7 @@
 
 @pytest.fixture(autouse=True)
 def storage() -> FileStorage:
-    clean_storage(init_normalize=True, init_loader=True)
+    return clean_test_storage(init_normalize=True, init_loader=True)
 
 
 @pytest.fixture(scope="module", autouse=True)
@@ -39,18 +39,6 @@ def logger_autouse() -> None:
     init_logger()
 
 
-def test_gen_configuration() -> None:
-    load = setup_loader()
-    assert ProductionLoaderConfiguration not in load.CONFIG.mro()
-    assert LoaderConfiguration in load.CONFIG.mro()
-    # for production config
-    with patch.dict(environ, {"IS_DEVELOPMENT_CONFIG": "False"}):
-        # mock missing config values
-        load = setup_loader(initial_values={"LOAD_VOLUME_PATH": LoaderConfiguration.LOAD_VOLUME_PATH})
-        assert ProductionLoaderConfiguration in load.CONFIG.mro()
-        assert LoaderConfiguration in load.CONFIG.mro()
-
-
 def test_spool_job_started() -> None:
     # default config keeps the job always running
     load = setup_loader()
@@ -63,7 +51,7 @@ def test_spool_job_started() -> None:
     jobs: List[LoadJob] = []
     for f in files:
         job = Load.w_spool_job(load, f, load_id, schema)
-        assert type(job) is client.LoadDummyJob
+        assert type(job) is dummy_impl.LoadDummyJob
         assert job.status() == "running"
         assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name()))
         jobs.append(job)
@@ -100,7 +88,7 @@ def test_unsupported_write_disposition() -> None:
 
 def test_spool_job_failed() -> None:
     # this config fails job on start
-    load = setup_loader(initial_client_values={"FAIL_PROB" : 1.0})
+    load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0))
     load_id, schema = prepare_load_package(
         load.load_storage,
         NORMALIZED_FILES
@@ -125,7 +113,7 @@ def test_spool_job_failed() -> None:
 
 def test_spool_job_retry_new() -> None:
     # this config retries job on start (transient fail)
-    load = setup_loader(initial_client_values={"RETRY_PROB" : 1.0})
+    load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0))
     load_id, schema = prepare_load_package(
         load.load_storage,
         NORMALIZED_FILES
@@ -145,7 +133,7 @@ def test_spool_job_retry_new() -> None:
 def test_spool_job_retry_started() -> None:
     # this config keeps the job always running
     load = setup_loader()
-    client.CLIENT_CONFIG = DummyClientConfiguration
+    # dummy_impl.CLIENT_CONFIG = DummyClientConfiguration
     load_id, schema = prepare_load_package(
         load.load_storage,
         NORMALIZED_FILES
@@ -154,7 +142,7 @@ def test_spool_job_retry_started() -> None:
     jobs: List[LoadJob] = []
     for f in files:
         job = Load.w_spool_job(load, f, load_id, schema)
-        assert type(job) is client.LoadDummyJob
+        assert type(job) is dummy_impl.LoadDummyJob
         assert job.status() == "running"
         assert load.load_storage.storage.has_file(load.load_storage._get_job_file_path(load_id, LoadStorage.STARTED_JOBS_FOLDER, job.file_name()))
         # mock job config to make it retry
@@ -166,7 +154,7 @@ def test_spool_job_retry_started() -> None:
     remaining_jobs = load.complete_jobs(load_id, jobs)
     assert len(remaining_jobs) == 0
     # clear retry flag
-    client.JOBS = {}
+    dummy_impl.JOBS = {}
     files = load.load_storage.list_new_jobs(load_id)
     assert len(files) == 2
     # parse the new job names
@@ -187,10 +175,10 @@ def test_try_retrieve_job() -> None:
     # manually move jobs to started
     files = load.load_storage.list_new_jobs(load_id)
     for f in files:
-        load.load_storage.start_job(load_id, JobClientBase.get_file_name_from_file_path(f))
+        load.load_storage.start_job(load_id, FileStorage.get_file_name_from_file_path(f))
     # dummy client may retrieve jobs that it created itself, jobs in started folder are unknown
     # and returned as terminal
-    with load.load_client_cls(schema) as c:
+    with load.destination.client(schema, load.initial_client_config) as c:
         job_count, jobs = load.retrieve_jobs(c, load_id)
     assert job_count == 2
     for j in jobs:
@@ -204,7 +192,7 @@ def test_try_retrieve_job() -> None:
     jobs_count, jobs = load.spool_new_jobs(load_id, schema)
     assert jobs_count == 2
     # now jobs are known
-    with load.load_client_cls(schema) as c:
+    with load.destination.client(schema, load.initial_client_config) as c:
         job_count, jobs = load.retrieve_jobs(c, load_id)
         assert job_count == 2
         for j in jobs:
@@ -212,27 +200,27 @@
 
 def test_completed_loop() -> None:
-    load = setup_loader(initial_client_values={"COMPLETED_PROB": 1.0})
+    load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0))
     assert_complete_job(load, load.load_storage.storage)
 
 
 def test_failed_loop() -> None:
     # ask to delete completed
-    load = setup_loader(initial_values={"DELETE_COMPLETED_JOBS": True}, initial_client_values={"FAIL_PROB": 1.0})
+    load = setup_loader(delete_completed_jobs=True, client_config=DummyClientConfiguration(fail_prob=1.0))
     # actually not deleted because one of the jobs failed
     assert_complete_job(load, load.load_storage.storage, should_delete_completed=False)
 
 
 def test_completed_loop_with_delete_completed() -> None:
-    load = setup_loader(initial_client_values={"COMPLETED_PROB": 1.0})
-    load.CONFIG.DELETE_COMPLETED_JOBS = True
+    load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0))
     load.load_storage = load.create_storage(is_storage_owner=False)
+    load.load_storage.config.delete_completed_jobs = True
     assert_complete_job(load, load.load_storage.storage, should_delete_completed=True)
 
 
 def test_retry_on_new_loop() -> None:
     # test job that retries sitting in new jobs
-    load = setup_loader(initial_client_values={"RETRY_PROB" : 1.0})
+    load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0))
     load_id, schema = prepare_load_package(
         load.load_storage,
         NORMALIZED_FILES
@@ -248,7 +236,7 @@ def test_retry_on_new_loop() -> None:
     files = load.load_storage.list_new_jobs(load_id)
     assert len(files) == 2
     # jobs will be completed
-    load = setup_loader(initial_client_values={"COMPLETED_PROB" : 1.0})
+    load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0))
     load.run(ThreadPool())
     files = load.load_storage.list_new_jobs(load_id)
     assert len(files) == 0
@@ -282,17 +270,13 @@ def test_exceptions() -> None:
         raise AssertionError()
 
 
-def test_version() -> None:
-    assert configuration({"CLIENT_TYPE": "dummy"})._VERSION == __version__
-
-
 def assert_complete_job(load: Load, storage: FileStorage, should_delete_completed: bool = False) -> None:
     load_id, _ = prepare_load_package(
         load.load_storage,
         NORMALIZED_FILES
     )
     # will complete all jobs
-    with patch.object(client.DummyClient, "complete_load") as complete_load:
+    with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load:
         load.run(ThreadPool())
         # did process schema update
         assert storage.has_file(os.path.join(load.load_storage.get_package_path(load_id), LoadStorage.PROCESSED_SCHEMA_UPDATES_FILE_NAME))
@@ -316,33 +300,27 @@ def prepare_load_package(load_storage: LoadStorage, cases: Sequence[str]) -> Tup
     load_storage.create_temp_load_package(load_id)
     for case in cases:
         path = f"./tests/load/cases/loading/{case}"
-        shutil.copy(path, load_storage.storage._make_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}"))
+        shutil.copy(path, load_storage.storage.make_full_path(f"{load_id}/{LoadStorage.NEW_JOBS_FOLDER}"))
     for f in ["schema_updates.json", "schema.json"]:
         path = f"./tests/load/cases/loading/{f}"
-        shutil.copy(path, load_storage.storage._make_path(load_id))
+        shutil.copy(path, load_storage.storage.make_full_path(load_id))
     load_storage.commit_temp_load_package(load_id)
     schema = load_storage.load_package_schema(load_id)
     return load_id, schema
 
 
-def setup_loader(initial_values: StrAny = None, initial_client_values: StrAny = None) -> Load:
+def setup_loader(delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None) -> Load:
     # reset jobs for a test
-    client.JOBS = {}
-
-    default_values = {
-        "CLIENT_TYPE": "dummy",
-        "DELETE_COMPLETED_JOBS": False
-    }
-    default_client_values = {
-        "LOADER_FILE_FORMAT": "jsonl"
-    }
-    if initial_values:
-        default_values.update(initial_values)
-    if initial_client_values:
-        default_client_values.update(initial_client_values)
+    dummy_impl.JOBS = {}
+    destination: DestinationReference = dummy
+    client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl")
+    # patch destination to provide client_config
+    # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config)
+
     # setup loader
-    return Load(
-        configuration(initial_values=default_values),
-        CollectorRegistry(auto_describe=True),
-        client_initial_values=default_client_values
+    with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}):
+        return Load(
+            destination,
+            CollectorRegistry(auto_describe=True),
+            initial_client_config=client_config
     )
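setup_loader now passes delete_completed_jobs through TEST_DICT_CONFIG_PROVIDER().values({...}) instead of mutating a CONFIG class. A hedged sketch of that kind of dict-backed provider, whose values() acts as a temporary override scope; dlt's real DictionaryProvider differs, this only illustrates the pattern:

from contextlib import contextmanager
from typing import Any, Dict, Iterator


class DictProvider:
    """Holds config values in a plain dict; meant as a high-priority test provider."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {}

    def get_value(self, key: str, default: Any = None) -> Any:
        return self._values.get(key, default)

    @contextmanager
    def values(self, overrides: Dict[str, Any]) -> Iterator["DictProvider"]:
        # temporarily overlay the overrides, restore the previous state on exit
        saved = dict(self._values)
        self._values.update(overrides)
        try:
            yield self
        finally:
            self._values = saved


provider = DictProvider()
with provider.values({"delete_completed_jobs": True}):
    assert provider.get_value("delete_completed_jobs") is True
assert provider.get_value("delete_completed_jobs") is None

Scoping the override to a with block keeps one test's configuration from leaking into the next, which is why the loader is constructed inside that block in the diff.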
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 9deed5a806..15ca425daf 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -1,21 +1,23 @@
 import contextlib
+from importlib import import_module
 import os
-from typing import Any, ContextManager, Iterable, Iterator, List, Sequence, cast, IO
+from typing import Any, ContextManager, Iterator, List, Sequence, cast, IO
 
 from dlt.common import json, Decimal
-from dlt.common.configuration import make_configuration
-from dlt.common.configuration.schema_volume_configuration import SchemaVolumeConfiguration
-from dlt.common.dataset_writers import write_insert_values, write_jsonl
-from dlt.common.file_storage import FileStorage
+from dlt.common.configuration import resolve_configuration
+from dlt.common.configuration.specs import SchemaVolumeConfiguration
+from dlt.common.destination import DestinationClientDwhConfiguration, DestinationReference, JobClientBase, LoadJob
+from dlt.common.data_writers import DataWriter
 from dlt.common.schema import TColumnSchema, TTableSchemaColumns
-from dlt.common.storages.schema_storage import SchemaStorage
+from dlt.common.storages import SchemaStorage, FileStorage
 from dlt.common.schema.utils import new_table
 from dlt.common.time import sleep
 from dlt.common.typing import StrAny
 from dlt.common.utils import uniq_id
 
 from dlt.load import Load
-from dlt.load.client_base import JobClientBase, LoadJob, SqlJobClientBase
+from dlt.load.job_client_impl import SqlJobClientBase
+
 
 TABLE_UPDATE: List[TColumnSchema] = [
     {
@@ -64,6 +66,7 @@
         "nullable": False
     },
 ]
+TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]:t for t in TABLE_UPDATE}
 
 TABLE_ROW = {
     "col1": 989127831,
@@ -86,7 +89,7 @@ def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: st
     file_name = uniq_id()
     file_storage.save(file_name, query.encode("utf-8"))
     table = Load.get_load_table(client.schema, table_name, file_name)
-    job = client.start_file_load(table, file_storage._make_path(file_name))
+    job = client.start_file_load(table, file_storage.make_full_path(file_name))
     while job.status() == "running":
         sleep(0.5)
     assert job.file_name() == file_name
@@ -105,36 +108,46 @@ def prepare_table(client: JobClientBase, case_name: str = "event_user", table_na
 
 
-def yield_client_with_storage(client_type: str, initial_values: StrAny = None) -> Iterator[SqlJobClientBase]:
+def yield_client_with_storage(destination_name: str, initial_values: StrAny = None) -> Iterator[SqlJobClientBase]:
     os.environ.pop("DEFAULT_DATASET", None)
+    # import destination reference by name
+    destination: DestinationReference = import_module(f"dlt.load.{destination_name}")
     # create dataset with random name
-    default_dataset = "test_" + uniq_id()
-    client_initial_values = {"DEFAULT_DATASET": default_dataset}
+    dataset_name = "test_" + uniq_id()
+    # create initial config
+    config: DestinationClientDwhConfiguration = None
+    config = destination.spec()()
+    # print(config.destination_name)
+    # print(destination.spec())
+    # print(destination.spec().destination_name)
+    config.dataset_name = dataset_name
+
     if initial_values is not None:
-        client_initial_values.update(initial_values)
+        # apply the values to credentials, if dict is provided it will be used as initial
+        config.credentials = initial_values
+        # also apply to config
+        config.update(initial_values)
     # get event default schema
-    C = make_configuration(SchemaVolumeConfiguration, SchemaVolumeConfiguration, initial_values={
-        "SCHEMA_VOLUME_PATH": "tests/common/cases/schemas/rasa"
+    C = resolve_configuration(SchemaVolumeConfiguration(), initial_value={
+        "schema_volume_path": "tests/common/cases/schemas/rasa"
     })
     schema_storage = SchemaStorage(C)
     schema = schema_storage.load_schema("event")
     # create client and dataset
     client: SqlJobClientBase = None
-    with Load.import_client_cls(client_type, initial_values=client_initial_values)(schema) as client:
+
+    with destination.client(schema, config) as client:
         client.initialize_storage()
         yield client
         client.sql_client.drop_dataset()
 
 
 @contextlib.contextmanager
-def cm_yield_client_with_storage(client_type: str, initial_values: StrAny = None) -> ContextManager[SqlJobClientBase]:
-    return yield_client_with_storage(client_type, initial_values)
+def cm_yield_client_with_storage(destination_name: str, initial_values: StrAny = None) -> ContextManager[SqlJobClientBase]:
+    return yield_client_with_storage(destination_name, initial_values)
 
 
-def write_dataset(client: JobClientBase, f: IO[Any], rows: Sequence[StrAny], headers: Iterable[str]) -> None:
-    if client.capabilities()["preferred_loader_file_format"] == "jsonl":
-        write_jsonl(f, rows)
-    elif client.capabilities()["preferred_loader_file_format"] == "insert_values":
-        write_insert_values(f, rows, headers)
-    else:
-        raise ValueError(client.capabilities()["preferred_loader_file_format"])
+def write_dataset(client: JobClientBase, f: IO[Any], rows: Sequence[StrAny], columns_schema: TTableSchemaColumns) -> None:
+    file_format = client.capabilities()["preferred_loader_file_format"]
+    writer = DataWriter.from_file_format(file_format, f)
+    writer.write_all(columns_schema, rows)
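yield_client_with_storage now resolves the destination module by name with import_module(f"dlt.load.{destination_name}") and then uses its spec() and client() attributes; those attribute names are taken from the diff, not from documented API. The stdlib-only snippet below sketches the plugin-by-module-name idea itself:

from importlib import import_module
from types import ModuleType


def load_plugin(package: str, name: str) -> ModuleType:
    # resolve "package.name" at runtime; raises ModuleNotFoundError for unknown names
    return import_module(f"{package}.{name}")


# stdlib demonstration: pick a submodule of the json package by name
decoder_mod = load_plugin("json", "decoder")
assert hasattr(decoder_mod, "JSONDecoder")

The same call with a user-supplied destination name turns a string such as "redshift" or "bigquery" into a module object whose attributes act as the plugin contract.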
diff --git a/tests/normalize/cases/ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2.extracted.json b/tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json
similarity index 100%
rename from tests/normalize/cases/ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2.extracted.json
rename to tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json
diff --git a/tests/normalize/cases/event_bot_load_metadata_1.extracted.json b/tests/normalize/cases/event.event.bot_load_metadata_2987398237498798.json
similarity index 100%
rename from tests/normalize/cases/event_bot_load_metadata_1.extracted.json
rename to tests/normalize/cases/event.event.bot_load_metadata_2987398237498798.json
diff --git a/tests/normalize/cases/event_many_load_2.extracted.json b/tests/normalize/cases/event.event.many_load_2.json
similarity index 100%
rename from tests/normalize/cases/event_many_load_2.extracted.json
rename to tests/normalize/cases/event.event.many_load_2.json
diff --git a/tests/normalize/cases/event_slot_session_metadata_1.extracted.json b/tests/normalize/cases/event.event.slot_session_metadata_1.json
similarity index 100%
rename from tests/normalize/cases/event_slot_session_metadata_1.extracted.json
rename to tests/normalize/cases/event.event.slot_session_metadata_1.json
diff --git a/tests/normalize/cases/event_user_load_1.extracted.json b/tests/normalize/cases/event.event.user_load_1.json
similarity index 100%
rename from tests/normalize/cases/event_user_load_1.extracted.json
rename to tests/normalize/cases/event.event.user_load_1.json
diff --git a/tests/normalize/cases/event_user_load_v228_1.extracted.json b/tests/normalize/cases/event.event.user_load_v228_1.json
similarity index 100%
rename from tests/normalize/cases/event_user_load_v228_1.extracted.json
rename to tests/normalize/cases/event.event.user_load_v228_1.json
diff --git a/tests/normalize/cases/schemas/ethereum_schema.json b/tests/normalize/cases/schemas/ethereum.schema.json
similarity index 100%
rename from tests/normalize/cases/schemas/ethereum_schema.json
rename to tests/normalize/cases/schemas/ethereum.schema.json
diff --git a/tests/normalize/cases/schemas/event_schema.json b/tests/normalize/cases/schemas/event.schema.json
similarity index 100%
rename from tests/normalize/cases/schemas/event_schema.json
rename to tests/normalize/cases/schemas/event.schema.json
diff --git a/tests/normalize/mock_rasa_json_normalizer.py b/tests/normalize/mock_rasa_json_normalizer.py
index f6fbde5d59..e516e7527a 100644
--- a/tests/normalize/mock_rasa_json_normalizer.py
+++ b/tests/normalize/mock_rasa_json_normalizer.py
@@ -1,17 +1,17 @@
 from dlt.common.normalizers.json import TNormalizedRowIterator
 from dlt.common.schema import Schema
 from dlt.common.normalizers.json.relational import normalize_data_item as relational_normalize, extend_schema
-from dlt.common.sources import with_table_name
 from dlt.common.typing import TDataItem
 
 
-def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str) -> TNormalizedRowIterator:
+def normalize_data_item(schema: Schema, source_event: TDataItem, load_id: str, table_name: str) -> TNormalizedRowIterator:
+    print(f"CUSTOM NORM: {schema.name} {table_name}")
     if schema.name == "event":
         # this emulates rasa parser on standard parser
-        event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"]}
-        yield from relational_normalize(schema, event, load_id)
+        event = {"sender_id": source_event["sender_id"], "timestamp": source_event["timestamp"], "type": source_event["event"]}
+        yield from relational_normalize(schema, event, load_id, table_name)
         # add table name which is "event" field in RASA OSS
-        with_table_name(source_event, "event_" + source_event["event"])
-
-        # will generate tables properly
-        yield from relational_normalize(schema, source_event, load_id)
+        yield from relational_normalize(schema, source_event, load_id, table_name + "_" + source_event["event"])
+    else:
+        # will generate tables properly
+        yield from relational_normalize(schema, source_event, load_id, table_name)
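The mock normalizer above now receives the table name and, for the "event" schema, routes each item both to the shared table and to an event-type specific table (table_name + "_" + event). A simplified, framework-free sketch of that dispatch; the row layout and function names are invented for illustration, not dlt's normalizer interface:

from typing import Any, Dict, Iterator, Tuple

Row = Tuple[str, Dict[str, Any]]  # (table name, flattened row)


def normalize_item(schema_name: str, item: Dict[str, Any], table_name: str) -> Iterator[Row]:
    if schema_name == "event":
        # a trimmed copy goes to the shared "event" table
        yield table_name, {"sender_id": item["sender_id"], "timestamp": item["timestamp"], "type": item["event"]}
        # the full payload goes to a per-event-type table, e.g. "event_bot"
        yield f"{table_name}_{item['event']}", dict(item)
    else:
        yield table_name, dict(item)


rows = list(normalize_item("event", {"event": "bot", "sender_id": "s1", "timestamp": 1.0}, "event"))
assert [t for t, _ in rows] == ["event", "event_bot"]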
diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py
index 703baf628f..a4e3280ce7 100644
--- a/tests/normalize/test_normalize.py
+++ b/tests/normalize/test_normalize.py
@@ -1,25 +1,24 @@
-from typing import Dict, List, Sequence
-import os
 import pytest
-import shutil
 from fnmatch import fnmatch
+from typing import Dict, List, Sequence
 from prometheus_client import CollectorRegistry
+from multiprocessing import get_start_method, Pool
 from multiprocessing.dummy import Pool as ThreadPool
 
 from dlt.common import json
+from dlt.common.destination import TLoaderFileFormat
 from dlt.common.utils import uniq_id
 from dlt.common.typing import StrAny
-from dlt.common.file_storage import FileStorage
 from dlt.common.schema import TDataType
-from dlt.common.storages.load_storage import LoadStorage
-from dlt.common.storages.normalize_storage import NormalizeStorage
-from dlt.common.storages import SchemaStorage
-from dlt.extract.extractor_storage import ExtractorStorageBase
+from dlt.common.storages import NormalizeStorage, LoadStorage
+from dlt.common.destination import DestinationCapabilitiesContext
+from dlt.common.configuration.container import Container
 
-from dlt.normalize import Normalize, configuration as normalize_configuration, __version__
+from dlt.extract.extract import ExtractorStorage
+from dlt.normalize import Normalize
 
 from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES
-from tests.utils import TEST_STORAGE, assert_no_dict_key_starts_with, write_version, clean_storage, init_logger
+from tests.utils import TEST_STORAGE_ROOT, TEST_DICT_CONFIG_PROVIDER, assert_no_dict_key_starts_with, write_version, clean_test_storage, init_logger
 from tests.normalize.utils import json_case_path
@@ -38,13 +37,14 @@ def rasa_normalize() -> Normalize:
 
 def init_normalize(default_schemas_path: str = None) -> Normalize:
-    clean_storage()
-    initial = {}
-    if default_schemas_path:
-        initial = {"IMPORT_SCHEMA_PATH": default_schemas_path, "EXTERNAL_SCHEMA_FORMAT": "json"}
-    n = Normalize(normalize_configuration(initial), CollectorRegistry())
-    # set jsonl as default writer
-    n.load_storage.preferred_file_format = n.CONFIG.LOADER_FILE_FORMAT = "jsonl"
+    clean_test_storage()
+    # pass schema config fields to schema storage via dict config provider
+    with TEST_DICT_CONFIG_PROVIDER().values({"import_schema_path": default_schemas_path, "external_schema_format": "json"}):
+        # inject the destination capabilities
+        with Container().injectable_context(DestinationCapabilitiesContext(preferred_loader_file_format="jsonl")):
+            n = Normalize(collector=CollectorRegistry())
+
+    assert n.load_storage.loader_file_format == n.loader_file_format == "jsonl"
     return n
@@ -58,13 +58,8 @@ def test_intialize(rasa_normalize: Normalize) -> None:
     pass
 
 
-# def test_empty_schema_name(raw_normalize: Normalize) -> None:
-#     schema = raw_normalize.load_or_create_schema("")
-#     assert schema.name == ""
-
-
 def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None:
-    expected_tables, load_files = normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES)
+    expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES)
     # load, parse and verify jsonl
     for expected_table in expected_tables:
         expect_lines_file(raw_normalize.load_storage, load_files[expected_table])
@@ -84,8 +79,9 @@ def test_normalize_single_user_event_jsonl(raw_normalize: Normalize) -> None:
 
 def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None:
-    raw_normalize.load_storage.preferred_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values"
-    expected_tables, load_files = normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES)
+    mock_destination_caps(raw_normalize, "insert_values")
+    raw_normalize.load_storage.loader_file_format = raw_normalize.loader_file_format = "insert_values"
+    expected_tables, load_files = normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES)
     # verify values line
     for expected_table in expected_tables:
         expect_lines_file(raw_normalize.load_storage, load_files[expected_table])
@@ -101,7 +97,7 @@ def test_normalize_single_user_event_insert(raw_normalize: Normalize) -> None:
 
 def test_normalize_filter_user_event(rasa_normalize: Normalize) -> None:
-    load_id = normalize_cases(rasa_normalize, ["event_user_load_v228_1"])
+    load_id = normalize_cases(rasa_normalize, ["event.event.user_load_v228_1"])
     load_files = expect_load_package(
         rasa_normalize.load_storage,
         load_id,
@@ -117,7 +113,7 @@ def test_normalize_filter_user_event(rasa_normalize: Normalize) -> None:
 
 def test_normalize_filter_bot_event(rasa_normalize: Normalize) -> None:
-    load_id = normalize_cases(rasa_normalize, ["event_bot_load_metadata_1"])
+    load_id = normalize_cases(rasa_normalize, ["event.event.bot_load_metadata_2987398237498798"])
     load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_bot"])
     event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_bot"], 0)
     assert lines == 1
@@ -127,7 +123,7 @@ def test_normalize_filter_bot_event(rasa_normalize: Normalize) -> None:
 
 def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None:
-    load_id = normalize_cases(rasa_normalize, ["event_slot_session_metadata_1"])
+    load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"])
     load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"])
     event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 0)
     assert lines == 1
@@ -140,8 +136,8 @@ def test_preserve_slot_complex_value_json_l(rasa_normalize: Normalize) -> None:
 
 def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None:
-    rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values"
-    load_id = normalize_cases(rasa_normalize, ["event_slot_session_metadata_1"])
+    mock_destination_caps(rasa_normalize, "insert_values")
+    load_id = normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"])
     load_files = expect_load_package(rasa_normalize.load_storage, load_id, ["event", "event_slot"])
     event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event_slot"], 2)
     assert lines == 3
@@ -153,33 +149,50 @@ def test_preserve_slot_complex_value_insert(rasa_normalize: Normalize) -> None:
 
 def test_normalize_raw_no_type_hints(raw_normalize: Normalize) -> None:
-    normalize_event_user(raw_normalize, "event_user_load_1", EXPECTED_USER_TABLES)
+    normalize_event_user(raw_normalize, "event.event.user_load_1", EXPECTED_USER_TABLES)
     assert_timestamp_data_type(raw_normalize.load_storage, "double")
 
 
 def test_normalize_raw_type_hints(rasa_normalize: Normalize) -> None:
-    normalize_cases(rasa_normalize, ["event_user_load_1"])
+    normalize_cases(rasa_normalize, ["event.event.user_load_1"])
     assert_timestamp_data_type(rasa_normalize.load_storage, "timestamp")
 
 
 def test_normalize_many_events_insert(rasa_normalize: Normalize) -> None:
-    rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values"
-    load_id = normalize_cases(rasa_normalize, ["event_many_load_2", "event_user_load_1"])
+    mock_destination_caps(rasa_normalize, "insert_values")
+    load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"])
     expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"]
     load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables)
     # return first values line from event_user file
     event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event"], 4)
+    # 2 lines header + 3 lines data
     assert lines == 5
     assert f"'{load_id}'" in event_text
 
 
+def test_normalize_many_events(rasa_normalize: Normalize) -> None:
+    load_id = normalize_cases(rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"])
+    expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"]
+    load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables)
+    # return first values line from event_user file
+    event_text, lines = expect_lines_file(rasa_normalize.load_storage, load_files["event"], 2)
+    # 3 lines data
+    assert lines == 3
+    assert f"{load_id}" in event_text
+
+
 def test_normalize_many_schemas(rasa_normalize: Normalize) -> None:
-    rasa_normalize.load_storage.preferred_file_format = rasa_normalize.CONFIG.LOADER_FILE_FORMAT = "insert_values"
-    copy_cases(
+    mock_destination_caps(rasa_normalize, "insert_values")
+    extract_cases(
         rasa_normalize.normalize_storage,
-        ["event_many_load_2", "event_user_load_1", "ethereum_blocks_9c1d9b504ea240a482b007788d5cd61c_2"]
+        ["event.event.many_load_2", "event.event.user_load_1", "ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2"]
     )
-    rasa_normalize.run(ThreadPool(processes=4))
+    if get_start_method() != "fork":
+        # windows, mac os do not support fork
+        rasa_normalize.run(ThreadPool(processes=4))
+    else:
+        # linux does so use real process pool in tests
+        rasa_normalize.run(Pool(processes=4))
     # must have two loading groups with model and event schemas
     loads = rasa_normalize.load_storage.list_packages()
     assert len(loads) == 2
@@ -198,8 +211,8 @@ def test_normalize_many_schemas(rasa_normalize: Normalize) -> None:
 
 def test_normalize_typed_json(raw_normalize: Normalize) -> None:
-    raw_normalize.load_storage.preferred_file_format = raw_normalize.CONFIG.LOADER_FILE_FORMAT = "jsonl"
-    extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special")
+    mock_destination_caps(raw_normalize, "jsonl")
+    extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], "special", "special")
     raw_normalize.run(ThreadPool(processes=1))
     loads = raw_normalize.load_storage.list_packages()
     assert len(loads) == 1
@@ -224,17 +237,13 @@ def test_normalize_typed_json(raw_normalize: Normalize) -> None:
         "event__parse_data__response_selector__default__response__responses"]
 
 
-def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str) -> None:
-    extractor = ExtractorStorageBase("1.0.0", True, FileStorage(os.path.join(TEST_STORAGE, "extractor"), makedirs=True), normalize_storage)
-    load_id = uniq_id()
-    extractor.save_json(f"{load_id}.json", items)
-    extractor.commit_events(
-        schema_name,
-        extractor.storage._make_path(f"{load_id}.json"),
-        "items",
-        len(items),
-        load_id
-    )
+def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str) -> None:
+    extractor = ExtractorStorage(normalize_storage.config)
+    extract_id = extractor.create_extract_id()
+    extractor.write_data_item(extract_id, schema_name, table_name, items, None)
+    extractor.close_writers(extract_id)
+    extractor.commit_extract_files(extract_id)
+
 
 def normalize_event_user(normalize: Normalize, case: str, expected_user_tables: List[str] = None) -> None:
     expected_user_tables = expected_user_tables or EXPECTED_USER_TABLES_RASA_NORMALIZER
@@ -243,21 +252,23 @@ def normalize_event_user(normalize: Normalize, case: str, expected_user_tables:
 
 def normalize_cases(normalize: Normalize, cases: Sequence[str]) -> str:
-    copy_cases(normalize.normalize_storage, cases)
+    extract_cases(normalize.normalize_storage, cases)
     load_id = uniq_id()
     normalize.load_storage.create_temp_load_package(load_id)
     # pool not required for map_single
-    dest_cases = [f"{NormalizeStorage.EXTRACTED_FOLDER}/{c}.extracted.json" for c in cases]
+    dest_cases = normalize.normalize_storage.storage.list_folder_files(NormalizeStorage.EXTRACTED_FOLDER)  # [f"{NormalizeStorage.EXTRACTED_FOLDER}/{c}.extracted.json" for c in cases]
     # create schema if it does not exist
-    normalize.load_or_create_schema("event")
+    Normalize.load_or_create_schema(normalize.schema_storage, "event")
     normalize.spool_files("event", load_id, normalize.map_single, dest_cases)
     return load_id
 
 
-def copy_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None:
+def extract_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None:
     for case in cases:
-        event_user_path = json_case_path(f"{case}.extracted")
-        shutil.copy(event_user_path, normalize_storage.storage._make_path(NormalizeStorage.EXTRACTED_FOLDER))
+        schema_name, table_name, _ = NormalizeStorage.parse_normalize_file_name(case + ".jsonl")
+        with open(json_case_path(case), "r", encoding="utf-8") as f:
+            items = json.load(f)
+        extract_items(normalize_storage, items, schema_name, table_name)
 
 
 def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables: Sequence[str]) -> Dict[str, str]:
@@ -266,7 +277,7 @@ def expect_load_package(load_storage: LoadStorage, load_id: str, expected_tables
     ofl: Dict[str, str] = {}
     for expected_table in expected_tables:
         # find all files for particular table, ignoring file id
-        file_mask = load_storage.build_job_file_name(expected_table, "*")
+        file_mask = load_storage.build_job_file_name(expected_table, "*", validate_components=False)
         # files are in normalized/<load_id>/new_jobs
         file_path = load_storage._get_job_file_path(load_id, "new_jobs", file_mask)
         candidates = [f for f in files if fnmatch(f, file_path)]
@@ -289,5 +300,7 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType)
     assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type
 
 
-def test_version() -> None:
-    assert normalize_configuration()._VERSION == __version__
+def mock_destination_caps(n: Normalize, loader_file_format: TLoaderFileFormat) -> None:
    # mock the loader file format
    # TODO: mock full capabilities here
+    n.load_storage.loader_file_format = n.loader_file_format = loader_file_format
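test_normalize_many_schemas now picks a real process pool only when the multiprocessing start method is "fork" and falls back to a thread pool elsewhere (Windows and macOS use spawn by default). A minimal standalone version of that guard:

from multiprocessing import get_start_method, Pool
from multiprocessing.dummy import Pool as ThreadPool


def make_pool(processes: int = 4):
    if get_start_method() != "fork":
        # spawn/forkserver platforms (Windows, macOS): stay in-process with threads
        return ThreadPool(processes=processes)
    # fork is available (Linux): use real worker processes
    return Pool(processes=processes)


with make_pool(2) as pool:
    print(pool.map(len, ["ab", "abc"]))  # [2, 3]

Both pool types expose the same map/apply interface, so the calling test code does not need to change when the pool implementation is swapped.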
diff --git a/tests/tools/create_storages.py b/tests/tools/create_storages.py
index 7606f433be..efbc8cf8ff 100644
--- a/tests/tools/create_storages.py
+++ b/tests/tools/create_storages.py
@@ -1,11 +1,7 @@
-from dlt.common.storages.normalize_storage import NormalizeStorage
-from dlt.common.configuration import NormalizeVolumeConfiguration
-from dlt.common.storages.load_storage import LoadStorage
-from dlt.common.configuration import LoadVolumeConfiguration
-from dlt.common.storages.schema_storage import SchemaStorage
-from dlt.common.configuration import SchemaVolumeConfiguration
+from dlt.common.storages import NormalizeStorage, LoadStorage, SchemaStorage
+from dlt.common.configuration.specs import NormalizeVolumeConfiguration, LoadVolumeConfiguration, SchemaVolumeConfiguration
 
-NormalizeStorage(True, NormalizeVolumeConfiguration)
-LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
-SchemaStorage(SchemaVolumeConfiguration.SCHEMA_VOLUME_PATH, makedirs=True)
+# NormalizeStorage(True, NormalizeVolumeConfiguration)
+# LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
+# SchemaStorage(SchemaVolumeConfiguration, makedirs=True)
diff --git a/tests/utils.py b/tests/utils.py
index be4d1b544c..17dc6884da 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,21 +1,34 @@
 import multiprocessing
 import platform
+from typing import Any, Mapping
 import requests
-from typing import Type
 import pytest
 import logging
 from os import environ
 
-from dlt.common.configuration.utils import _get_config_attrs_with_hints, make_configuration
-from dlt.common.configuration import RunConfiguration
+from dlt.common.configuration.container import Container
+from dlt.common.configuration.providers import EnvironProvider, DictionaryProvider
+from dlt.common.configuration.resolve import resolve_configuration, serialize_value
+from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration
+from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
 from dlt.common.logger import init_logging_from_config
-from dlt.common.file_storage import FileStorage
+from dlt.common.storages import FileStorage
 from dlt.common.schema import Schema
 from dlt.common.storages.versioned_storage import VersionedStorage
 from dlt.common.typing import StrAny
 
-TEST_STORAGE = "_storage"
+TEST_STORAGE_ROOT = "_storage"
+
+
+# add test dictionary provider
+def TEST_DICT_CONFIG_PROVIDER():
+    providers_context = Container()[ConfigProvidersContext]
+    try:
+        return providers_context[DictionaryProvider.NAME]
+    except KeyError:
+        provider = DictionaryProvider()
+        providers_context.add_provider(provider)
+        return provider
 
 
 class MockHttpResponse():
@@ -31,20 +44,20 @@ def write_version(storage: FileStorage, version: str) -> None:
     storage.save(VersionedStorage.VERSION_FILE, str(version))
 
 
-def delete_storage() -> None:
-    storage = FileStorage(TEST_STORAGE)
+def delete_test_storage() -> None:
+    storage = FileStorage(TEST_STORAGE_ROOT)
     if storage.has_folder(""):
         storage.delete_folder("", recursively=True)
 
 
 @pytest.fixture()
-def root_storage() -> FileStorage:
-    return clean_storage()
+def test_storage() -> FileStorage:
+    return clean_test_storage()
 
 
 @pytest.fixture(autouse=True)
-def autouse_root_storage() -> FileStorage:
-    return clean_storage()
+def autouse_test_storage() -> FileStorage:
+    return clean_test_storage()
 
 
 @pytest.fixture(scope="module", autouse=True)
@@ -55,37 +68,37 @@ def preserve_environ() -> None:
     environ.update(saved_environ)
 
 
-def init_logger(C: Type[RunConfiguration] = None) -> None:
+def init_logger(C: RunConfiguration = None) -> None:
     if not hasattr(logging, "health"):
         if not C:
-            C = make_configuration(RunConfiguration, RunConfiguration)
+            C = resolve_configuration(RunConfiguration())
         init_logging_from_config(C)
 
 
-def clean_storage(init_normalize: bool = False, init_loader: bool = False) -> FileStorage:
-    storage = FileStorage(TEST_STORAGE, "t", makedirs=True)
+def clean_test_storage(init_normalize: bool = False, init_loader: bool = False) -> FileStorage:
+    storage = FileStorage(TEST_STORAGE_ROOT, "t", makedirs=True)
     storage.delete_folder("", recursively=True)
     storage.create_folder(".")
     if init_normalize:
-        from dlt.common.storages.normalize_storage import NormalizeStorage
-        from dlt.common.configuration import NormalizeVolumeConfiguration
-        NormalizeStorage(True, NormalizeVolumeConfiguration)
+        from dlt.common.storages import NormalizeStorage
+        NormalizeStorage(True)
     if init_loader:
-        from dlt.common.storages.load_storage import LoadStorage
-        from dlt.common.configuration import LoadVolumeConfiguration
-        LoadStorage(True, LoadVolumeConfiguration, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
+        from dlt.common.storages import LoadStorage
+        LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS)
     return storage
 
 
-def add_config_to_env(config: Type[RunConfiguration]) -> None:
+def add_config_to_env(config: BaseConfiguration) -> None:
     # write back default values in configuration back into environment
-    possible_attrs = _get_config_attrs_with_hints(config).keys()
-    for attr in possible_attrs:
-        if attr not in environ:
-            v = getattr(config, attr)
+    return add_config_dict_to_env(dict(config), config.__namespace__)
+
+
+def add_config_dict_to_env(dict_: Mapping[str, Any], namespace: str = None, overwrite_keys: bool = False) -> None:
+    for k, v in dict_.items():
+        env_key = EnvironProvider.get_key_name(k, namespace)
+        if env_key not in environ or overwrite_keys:
             if v is not None:
-                # print(f"setting {attr} to {v}")
-                environ[attr] = str(v)
+                environ[env_key] = serialize_value(v)
 
 
 def create_schema_with_name(schema_name) -> Schema:
@@ -103,4 +116,5 @@ def assert_no_dict_key_starts_with(d: StrAny, key_prefix: str) -> None:
 skipifpypy = pytest.mark.skipif(
     platform.python_implementation() == "PyPy",
     reason="won't run in PyPy interpreter"
-)
\ No newline at end of file
+)
+
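add_config_dict_to_env writes configuration values back into the environment under namespaced keys, delegating key construction to EnvironProvider.get_key_name and value conversion to serialize_value. A self-contained approximation of that round trip, with the NAMESPACE__KEY convention and str() serialization hard-coded as assumptions:

import os
from typing import Any, Mapping


def env_key_name(key: str, namespace: str = None) -> str:
    # assumed convention: upper-cased, namespace joined with a double underscore
    return f"{namespace.upper()}__{key.upper()}" if namespace else key.upper()


def add_dict_to_env(values: Mapping[str, Any], namespace: str = None, overwrite: bool = False) -> None:
    for k, v in values.items():
        env_key = env_key_name(k, namespace)
        if v is not None and (overwrite or env_key not in os.environ):
            os.environ[env_key] = str(v)


add_dict_to_env({"log_level": "INFO", "sentry_dsn": None}, namespace="runtime")
assert os.environ["RUNTIME__LOG_LEVEL"] == "INFO"
assert "RUNTIME__SENTRY_DSN" not in os.environ

Skipping None values and not overwriting existing keys by default keeps explicitly set environment variables authoritative, which matches the guard in the diff.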