Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update: add some pep8 and fix typehint #345

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 90 additions & 66 deletions src/environs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,44 @@
import logging
import os
import re
import typing
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import ParseResult, urlparse

import marshmallow as ma
from dotenv.main import _walk_to_root, load_dotenv

__all__ = ["EnvError", "Env"]

_T = typing.TypeVar("_T")
_T = TypeVar("_T")
_StrType = str
_BoolType = bool

ErrorMapping = typing.Mapping[str, typing.List[str]]
ErrorList = typing.List[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[typing.Type, typing.Callable[..., _T], ma.fields.Field]
FieldType = typing.Type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable
ErrorMapping = Mapping[str, List[str]]
ErrorList = List[str]
FieldFactory = Callable[..., ma.fields.Field]
Subcast = Union[Type, Callable[..., _T], ma.fields.Field]
FieldType = Type[ma.fields.Field]
FieldOrFactory = Union[FieldType, FieldFactory]
ParserMethod = Callable


_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^\}:]*)?\}")
_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^}:]*)?}")


class EnvError(ValueError):
Expand All @@ -41,7 +54,9 @@ class EnvError(ValueError):

class EnvValidationError(EnvError):
def __init__(
self, message: str, error_messages: typing.Union[ErrorList, ErrorMapping]
self,
message: str,
error_messages: Union[ErrorList, ErrorMapping],
):
self.error_messages = error_messages
super().__init__(message)
Expand All @@ -61,29 +76,29 @@ def _field2method(
field_or_factory: FieldOrFactory,
method_name: str,
*,
preprocess: typing.Optional[typing.Callable] = None,
preprocess_kwarg_names: typing.Sequence[str] = tuple(),
preprocess: Optional[Callable] = None,
preprocess_kwarg_names: Sequence[str] = tuple(),
) -> ParserMethod:
def method(
self: "Env",
name: str,
default: typing.Any = ma.missing,
subcast: typing.Optional[Subcast] = None,
default: Any = ma.missing,
subcast: Optional[Subcast] = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
load_default: Any = ma.missing,
validate: Optional[
Union[
Callable[[Any], Any],
Iterable[Callable[[Any], Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
allow_none: Optional[bool] = None,
error_messages: Optional[Dict[str, str]] = None,
metadata: Optional[Mapping[str, Any]] = None,
**kwargs,
) -> typing.Optional[_T]:
) -> Optional[_T]:
if self._sealed:
raise EnvSealedError(
"Env has already been sealed. New values cannot be parsed."
Expand All @@ -100,14 +115,16 @@ def method(
name: kwargs.pop(name) for name in preprocess_kwarg_names if name in kwargs
}
if isinstance(field_or_factory, type) and issubclass(
field_or_factory, ma.fields.Field
field_or_factory,
ma.fields.Field,
):
field = field_or_factory(**field_kwargs, **kwargs)
else:
parsed_subcast = _make_subcast_field(subcast)
field = field_or_factory(subcast=parsed_subcast, **field_kwargs)
parsed_key, value, proxied_key = self._get_from_environ(
name, field.load_default
name,
field.load_default,
)
self._fields[parsed_key] = field
source_key = proxied_key or parsed_key
Expand Down Expand Up @@ -138,13 +155,13 @@ def method(
return method


def _func2method(func: typing.Callable, method_name: str) -> ParserMethod:
def _func2method(func: Callable, method_name: str) -> ParserMethod:
def method(
self: "Env",
name: str,
default: typing.Any = ma.missing,
default: Any = ma.missing,
**kwargs,
) -> typing.Optional[_T]:
) -> Optional[_T]:
if self._sealed:
raise EnvSealedError(
"Env has already been sealed. New values cannot be parsed."
Expand Down Expand Up @@ -187,8 +204,8 @@ def method(


def _make_subcast_field(
subcast: typing.Optional[Subcast],
) -> typing.Type[ma.fields.Field]:
subcast: Optional[Subcast],
) -> Type[ma.fields.Field]:
if isinstance(subcast, type) and subcast in ma.Schema.TYPE_MAPPING:
inner_field = ma.Schema.TYPE_MAPPING[subcast]
elif isinstance(subcast, type) and issubclass(subcast, ma.fields.Field):
Expand All @@ -197,7 +214,7 @@ def _make_subcast_field(

class SubcastField(ma.fields.Field):
def _deserialize(self, value, *args, **kwargs):
func = typing.cast(typing.Callable[..., _T], subcast)
func = cast(Callable[..., _T], subcast)
return func(value)

inner_field = SubcastField
Expand All @@ -206,43 +223,43 @@ def _deserialize(self, value, *args, **kwargs):
return inner_field


def _make_list_field(*, subcast: typing.Optional[type], **kwargs) -> ma.fields.List:
def _make_list_field(*, subcast: Optional[type], **kwargs) -> ma.fields.List:
inner_field = _make_subcast_field(subcast)
return ma.fields.List(inner_field, **kwargs)


def _preprocess_list(
value: typing.Union[str, typing.Iterable], *, delimiter: str = ",", **kwargs
) -> typing.Iterable:
value: Union[str, Iterable], *, delimiter: str = ",", **kwargs
) -> None | list[str] | list[Any]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we support python 3.8, we can't use this style of typing quite yet

if ma.utils.is_iterable_but_not_string(value) or value is None:
return value
return typing.cast(str, value).split(delimiter) if value != "" else []
return cast(str, value).split(delimiter) if value != "" else []


def _preprocess_dict(
value: typing.Union[str, typing.Mapping],
value: Union[str, Mapping],
*,
subcast_keys: typing.Optional[Subcast] = None,
subcast_values: typing.Optional[Subcast] = None,
subcast_keys: Optional[Subcast] = None,
subcast_values: Optional[Subcast] = None,
delimiter: str = ",",
**kwargs,
) -> typing.Mapping:
) -> Mapping:
if isinstance(value, Mapping):
return value
subcast_keys_instance: ma.fields.Field = _make_subcast_field(subcast_keys)(**kwargs)
subcast_values_instance: ma.fields.Field = _make_subcast_field(subcast_values)(
**kwargs
**kwargs,
)

return {
subcast_keys_instance.deserialize(
key.strip()
key.strip(),
): subcast_values_instance.deserialize(val.strip())
for key, val in (item.split("=", 1) for item in value.split(delimiter) if value)
}


def _preprocess_json(value: typing.Union[str, typing.Mapping, typing.List], **kwargs):
def _preprocess_json(value: Union[str, Mapping, List], **kwargs):
try:
if isinstance(value, str):
return pyjson.loads(value)
Expand All @@ -254,19 +271,19 @@ def _preprocess_json(value: typing.Union[str, typing.Mapping, typing.List], **kw
raise ma.ValidationError("Not valid JSON.") from error


_EnumT = typing.TypeVar("_EnumT", bound=Enum)
_EnumT = TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> _EnumT:
invalid_exc = ma.ValidationError(f"Not a valid '{type.__name__}' enum.")
def _enum_parser(value, _type: Type[_EnumT], ignore_case: bool = False) -> _EnumT:
invalid_exc = ma.ValidationError(f"Not a valid '{_type.__name__}' enum.")

if not ignore_case:
try:
return type[value]
return _type[value]
except KeyError as error:
raise invalid_exc from error

for enum_value in type:
for enum_value in _type:
if enum_value.name.lower() == value.lower():
return enum_value

Expand Down Expand Up @@ -326,8 +343,8 @@ def _serialize(self, value: ParseResult, *args, **kwargs) -> str:
def deserialize(
self,
value: str,
attr: typing.Optional[str] = None,
data: typing.Optional[typing.Mapping] = None,
attr: Optional[str] = None,
data: Optional[Mapping] = None,
**kwargs,
) -> ParseResult:
ret = super().deserialize(value, attr, data, **kwargs)
Expand Down Expand Up @@ -400,18 +417,18 @@ def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False):
self.eager = eager
self._sealed: bool = False
self.expand_vars = expand_vars
self._fields: typing.Dict[_StrType, typing.Union[ma.fields.Field, type]] = {}
self._values: typing.Dict[_StrType, typing.Any] = {}
self._fields: Dict[_StrType, Union[ma.fields.Field, type]] = {}
self._values: Dict[_StrType, Any] = {}
self._errors: ErrorMapping = collections.defaultdict(list)
self._prefix: typing.Optional[_StrType] = None
self.__custom_parsers__: typing.Dict[_StrType, ParserMethod] = {}
self._prefix: Optional[_StrType] = None
self.__custom_parsers__: Dict[_StrType, ParserMethod] = {}

def __repr__(self) -> _StrType:
return f"<{self.__class__.__name__}(eager={self.eager}, expand_vars={self.expand_vars})>" # noqa: E501

@staticmethod
def read_env(
path: typing.Optional[_StrType] = None,
path: Optional[_StrType] = None,
recurse: _BoolType = True,
verbose: _BoolType = False,
override: _BoolType = False,
Expand Down Expand Up @@ -449,7 +466,7 @@ def read_env(
return load_dotenv(str(start), verbose=verbose, override=override)

@contextlib.contextmanager
def prefixed(self, prefix: _StrType) -> typing.Iterator["Env"]:
def prefixed(self, prefix: _StrType) -> Iterator["Env"]:
"""Context manager for parsing envvars with a common prefix."""
try:
old_prefix = self._prefix
Expand Down Expand Up @@ -482,7 +499,7 @@ def __getattr__(self, name: _StrType):
except KeyError as error:
raise AttributeError(f"{self} has no attribute {name}") from error

def add_parser(self, name: _StrType, func: typing.Callable) -> None:
def add_parser(self, name: _StrType, func: Callable) -> None:
"""Register a new parser method with the name ``name``. ``func`` must
receive the input value for an environment variable.
"""
Expand All @@ -494,36 +511,43 @@ def add_parser(self, name: _StrType, func: typing.Callable) -> None:
return None

def parser_for(
self, name: _StrType
) -> typing.Callable[[typing.Callable], typing.Callable]:
self,
name: _StrType,
) -> Callable[[Callable], Callable]:
"""Decorator that registers a new parser method with the name ``name``.
The decorated function must receive the input value for an environment variable.
"""

def decorator(func: typing.Callable) -> typing.Callable:
def decorator(func: Callable) -> Callable:
self.add_parser(name, func)
return func

return decorator

def add_parser_from_field(
self, name: _StrType, field_cls: typing.Type[ma.fields.Field]
self,
name: _StrType,
field_cls: Type[ma.fields.Field],
):
"""Register a new parser method with name ``name``,
given a marshmallow ``Field``.
"""
self.__custom_parsers__[name] = _field2method(field_cls, method_name=name)

def dump(self) -> typing.Mapping[_StrType, typing.Any]:
def dump(self) -> Mapping[_StrType, Any]:
"""Dump parsed environment variables to a dictionary of simple data types
(numbers and strings).
"""
schema = ma.Schema.from_dict(self._fields)()
return schema.dump(self._values)

def _get_from_environ(
self, key: _StrType, default: typing.Any, *, proxied: _BoolType = False
) -> typing.Tuple[_StrType, typing.Any, typing.Optional[_StrType]]:
self,
key: _StrType,
default: Any,
*,
proxied: _BoolType = False,
) -> Tuple[_StrType, Any, Optional[_StrType]]:
"""Access a value from os.environ. Handles proxied variables,
e.g. SMTP_LOGIN={{MAILGUN_LOGIN}}.

Expand All @@ -540,7 +564,7 @@ def _get_from_environ(
expand_match = self.expand_vars and _EXPANDED_VAR_PATTERN.match(value)
if expand_match: # Full match expand_vars - special case keep default
proxied_key: _StrType = expand_match.groups()[0]
subs_default: typing.Optional[_StrType] = expand_match.groups()[1]
subs_default: Optional[_StrType] = expand_match.groups()[1]
if subs_default is not None:
default = subs_default[2:]
elif (
Expand Down
Loading