Skip to content

Commit

Permalink
feature(loaders): Factored out loaders.py.
Browse files Browse the repository at this point in the history
Simplifies ``CreateYamlSettings``.
Updates ``loaders.py`` with recent changes from ``release/v2``.
  • Loading branch information
acederberg committed Jul 16, 2024
1 parent 148d78b commit 9794b6f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 266 deletions.
248 changes: 28 additions & 220 deletions yaml_settings_pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@
from pathlib import Path, PosixPath
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, TypeVar

from jsonpath_ng import parse
from pydantic.fields import FieldInfo
from pydantic.v1.utils import deep_update
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
from typing_extensions import Doc

from yaml_settings_pydantic import loader
from yaml_settings_pydantic.loader import (
DEFAULT_YAML_FILE_CONFIG_DICT,
YamlSettingsConfigDict,
)
from typing_extensions import Doc, NotRequired, TypedDict
from yaml import safe_load

__version__ = "2.3.1"
logger = logging.getLogger("yaml_settings_pydantic")
Expand All @@ -40,95 +39,6 @@
T = TypeVar("T")


class YamlFileConfigDict(TypedDict, total=False):
    """Per-file loading options.

    Instances of this dict are the values of
    ``YamlSettingsConfigDict.yaml_files`` — one per configured file path.
    All keys are optional (``total=False`` plus ``NotRequired``).
    """

    # NOTE: ``NotRequired`` keeps each key individually optional.
    envvar: NotRequired[
        Annotated[
            str | None,
            Doc(
                "Env variable for the configuration path. If this env variable "
                "is defined it will overwrite the path to which this dict is "
                "associated within ``YamlSettingsConfigDict.yaml_files`` via keys."
            ),
        ]
    ]

    subpath: NotRequired[
        Annotated[
            str | None,
            Doc("The configuration subpath of the file (using json path)."),
        ]
    ]

    required: NotRequired[
        Annotated[
            bool,
            Doc("The file specified is required."),
        ]
    ]


class YamlFileData(TypedDict):
    """Content loaded from a single file together with its provenance."""

    config: Annotated[
        YamlFileConfigDict,
        Doc("Configuration from which this data was ascertained."),
    ]
    source: Annotated[
        Path,
        Doc(
            "Origin of the content. This is here because environment "
            "variables can overwrite the source path (provided in "
            "``YamlSettingsConfigDict.files``)."
        ),
    ]
    content: Annotated[
        Any,
        Doc("Content loaded from :attr:`source`."),
    ]


# Defaults applied when a file is configured without explicit options: no
# environment-variable override, no subpath, and the file is required.
DEFAULT_YAML_FILE_CONFIG_DICT = YamlFileConfigDict(
    envvar=None, subpath=None, required=True
)


class YamlSettingsConfigDict(SettingsConfigDict, TypedDict):
    """``model_config`` shape for settings classes that load YAML.

    Extends pydantic-settings' ``SettingsConfigDict`` with the keys below.
    """

    # NOTE: Unlike ``yaml_reload``, this key is not ``NotRequired`` — a
    #       settings class using this config dict must say which files to load.
    yaml_files: Annotated[
        set[Path]
        | Sequence[Path]
        | dict[Path, YamlFileConfigDict]
        | Path
        | set[str]
        | Sequence[str]
        | dict[str, YamlFileConfigDict]
        | str,
        Doc(
            "Files to load. This can be a ``str`` or ``Sequence`` of "
            "configuration paths, or a dictionary of file names mapping to "
            "their options. This data is hydrated by ``CreateYamlSettings`` "
            "into the dictionary form ``dict[str, YamlFileConfigDict]`` no "
            "matter the form in which it is provided."
        ),
    ]

    yaml_reload: NotRequired[
        Annotated[
            bool | None,
            Doc("Reload files on object construction when ``True``."),
        ]
    ]


def resolve_filepaths(fp: Path, fp_config: YamlFileConfigDict) -> Path:
    """Resolve the effective path for a configured file.

    When ``fp_config`` names an environment variable (``envvar``) and that
    variable holds a non-empty value, that value replaces ``fp``; otherwise
    ``fp`` is returned unchanged.
    """
    env_name = fp_config.get("envvar")
    override = environ.get(env_name) if env_name is not None else None
    # An unset or empty environment variable does not override the default.
    return Path(override) if override else fp


class CreateYamlSettings(PydanticBaseSettingsSource):
"""Create a ``yaml`` setting loader middleware.
Expand All @@ -140,7 +50,7 @@ class CreateYamlSettings(PydanticBaseSettingsSource):
# Info

files: Annotated[
dict[Path, YamlFileConfigDict],
dict[Path, loader.YamlFileConfigDict],
Doc(
"``YAML`` or ``JSON`` files to load and loading specifications ("
"in the form of :class:`YamlFileConfigDict`)."
Expand Down Expand Up @@ -218,10 +128,10 @@ def validate_reload(self, settings_cls: type[BaseSettings]) -> bool:

def validate_files(
self, settings_cls: type[BaseSettings]
) -> dict[Path, YamlFileConfigDict]:
) -> dict[Path, loader.YamlFileConfigDict]:
"""Validate ``model_config["files"]``."""

found_value: dict[Path, YamlFileConfigDict] | str | Sequence[str] | None
found_value: dict[Path, loader.YamlFileConfigDict] | str | Sequence[str] | None
found_value = self.get_settings_cls_value(settings_cls, "files", None)
item = f"{settings_cls.__name__}.model_config.yaml_files"

Expand All @@ -246,8 +156,8 @@ def validate_files(
# else just leave it.
values: (
tuple[Path, ...]
| dict[str, YamlFileConfigDict]
| dict[Path, YamlFileConfigDict]
| dict[str, loader.YamlFileConfigDict]
| dict[Path, loader.YamlFileConfigDict]
)
if isinstance(found_value, PosixPath):
logger.debug(f"`{item}` was a PosixPath.")
Expand All @@ -266,12 +176,12 @@ def validate_files(
)

# NOTE: Create dictionary if the sequence is not a dictionary.
files: dict[Path, YamlFileConfigDict]
files: dict[Path, loader.YamlFileConfigDict]
if not isinstance(values, dict):
files = {
(
k if isinstance(k, Path) else Path(k)
): DEFAULT_YAML_FILE_CONFIG_DICT.copy()
): loader.DEFAULT_YAML_FILE_CONFIG_DICT.copy()
for k in values
}
elif any(not isinstance(v, dict) for v in values.values()):
Expand All @@ -280,7 +190,7 @@ def validate_files(
raise ValueError("`files` cannot have length `0`.")
else:
for k, v in values.items():
vv = DEFAULT_YAML_FILE_CONFIG_DICT.copy()
vv = loader.DEFAULT_YAML_FILE_CONFIG_DICT.copy()
vv.update(v)
values[k] = v
files = values
Expand Down Expand Up @@ -331,124 +241,13 @@ def get_settings_cls_value(
# ----------------------------------------------------------------------- #
# Loading

def validate_yaml_data_content(
    self,
    fp: Path,
    fp_data: YamlFileData,
) -> tuple[dict[str, Any], Path | None]:
    """Extract the configured subpath from one loaded file.

    :param fp: Key under which the file was registered in :attr:`files`.
    :param fp_data: Loaded content plus the configuration it came from.
    :raises ValueError: When a configured ``subpath`` is absent from the
        loaded content.
    :returns: ``(extracted, invalid_fp)`` where ``invalid_fp`` is ``fp``
        when the extracted value is not a dictionary and ``None`` otherwise.
    """

    fp_config = fp_data["config"]
    content = fp_data["content"]

    if (subpath := fp_config.get("subpath")) is not None:
        jsonpath_exp = parse(subpath)

        extracted = next(iter(jsonpath_exp.find(content)), None)
        if extracted is None:
            msg = f"Could not find path `{subpath}` in `{fp}`."
            raise ValueError(msg)

        extracted = extracted.value
    else:
        extracted = content

    # NOTE: Validity must be judged on the value actually merged
    #       (``extracted``), not on the raw document — with a subpath the
    #       two can differ. The previous check inspected ``content``.
    return extracted, None if isinstance(extracted, dict) else fp

def validate_yaml_data(
    self,
    yaml_data: dict[Path, YamlFileData],
) -> dict[str, Any]:
    """Extract subpaths from loaded YAML and deep-merge the results.

    :param yaml_data: Loaded YAML files from :attr:`files`.
    :raises ValueError: When the subpaths cannot be found or when
        documents do not deserialize to dictionaries at their subpath.
    :returns: The deep merge of every file's (sub)document.
    """

    if not yaml_data:
        return dict()

    content: tuple[dict[str, Any], ...]
    fp_invalid_unfiltered: tuple[Path | None, ...]

    content, fp_invalid_unfiltered = zip(
        *(
            self.validate_yaml_data_content(fp, fp_data)
            for fp, fp_data in yaml_data.items()
        ),
    )

    fp_invalid = tuple(fp for fp in fp_invalid_unfiltered if fp is not None)
    if len(fp_invalid):
        fmt = " - `file={0}`\n`subpath={1}`"
        # NOTE: ``subpath`` lives under the ``config`` key of
        #       ``YamlFileData``; reading it off the top level (as before)
        #       always rendered ``subpath=None`` in this message.
        msg = "\n".join(
            fmt.format(fp, yaml_data[fp]["config"].get("subpath"))
            for fp in fp_invalid
        )
        msg = (
            "Input files must deserialize to dictionaries at their "
            f"specified subpaths:\n{msg}"
        )
        raise ValueError(msg)

    logger.debug("Merging file results.")
    return deep_update(*content)

def load_yaml_data(self) -> dict[Path, YamlFileData]:
    """Load file contents without validation.

    :raises ValueError: When a required file does not exist.
    :returns: Mapping from configured path to its loaded data.
    """

    # NOTE: Resolve environment-variable overwrites up front so the
    #       required-file check below sees the effective paths.
    filepaths: dict[tuple[Path, Path], YamlFileConfigDict]
    filepaths = {
        (fp_default, resolve_filepaths(fp_default, fp_config)): fp_config
        for fp_default, fp_config in self.files.items()
    }

    # NOTE: No files to check.
    if not filepaths:
        return dict()

    # NOTE: If any required files are missing, raise an error.
    fp_resolved_required_missing = {
        fp_resolved
        for (_, fp_resolved), fp_config in filepaths.items()
        if fp_config.get("required") and not fp_resolved.is_file()
    }
    if fp_resolved_required_missing:
        raise ValueError(
            "The following files are required but do not exist: "
            f"`{fp_resolved_required_missing}`."
        )

    # NOTE: ``with`` guarantees each handle is closed even when
    #       ``safe_load`` raises; the previous bulk open/close leaked
    #       every handle on error.
    yaml_data: dict[Path, YamlFileData] = {}
    for (fp_default, fp_resolved), fp_config in filepaths.items():
        if not fp_resolved.exists():
            continue
        with fp_resolved.open() as stream:
            # NOTE: ``source`` records the *resolved* path — per the
            #       ``YamlFileData.source`` docs it exists to capture
            #       env-var overwrites, which ``fp_default`` never shows.
            yaml_data[fp_default] = YamlFileData(
                content=safe_load(stream),
                source=fp_resolved,
                config=fp_config,
            )

    return yaml_data

def load(self) -> dict[str, Any]:
    """Load data and validate that it is sufficiently shaped for
    ``BaseSettings``.

    :returns: Merged, validated settings data.
    """

    # NOTE: The duplicated, unreachable statements that followed the
    #       ``return`` (diff residue) have been removed; keep the raw
    #       per-file data around for inspection.
    self._yaml_data = (yaml_data := self.load_yaml_data())
    return self.validate_yaml_data(yaml_data)


class BaseYamlSettings(BaseSettings):
Expand All @@ -469,7 +268,7 @@ class BaseYamlSettings(BaseSettings):
if TYPE_CHECKING:
# NOTE: pydantic>=2.7 checks at load time for annotated fields, and
# thinks that `model_config` is a model field name.
model_config: ClassVar[YamlSettingsConfigDict]
model_config: ClassVar[loader.YamlSettingsConfigDict]

__yaml_files__: ClassVar[Sequence[str] | None]
__yaml_reload__: ClassVar[bool | None]
Expand Down Expand Up @@ -500,4 +299,13 @@ def settings_customise_sources(
)


# NOTE(review): This import sits mid-file in the original; kept so the
# re-exported names resolve. Confirm against ``yaml_settings_pydantic.loader``.
from yaml_settings_pydantic.loader import YamlFileConfigDict, resolve_filepaths

# NOTE: The earlier three-name ``__all__`` assignment was dead code — it was
#       immediately overwritten by this tuple — and has been removed.
__all__ = (
    "resolve_filepaths",
    "CreateYamlSettings",
    "BaseYamlSettings",
    "YamlSettingsConfigDict",
    "YamlFileConfigDict",
    "DEFAULT_YAML_FILE_CONFIG_DICT",
)
Loading

0 comments on commit 9794b6f

Please sign in to comment.