Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type hints for methods #371

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 130 additions & 29 deletions src/environs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
import builtins
import collections
import contextlib
import datetime as dt
import decimal
import functools
import inspect
import json as pyjson
import logging
import os
import re
import typing
import uuid
from collections.abc import Mapping
from datetime import timedelta
from enum import Enum
from pathlib import Path
from urllib.parse import ParseResult, urlparse

import marshmallow as ma
from dj_database_url import DBConfig
from dotenv.main import _walk_to_root, load_dotenv

__all__ = ["EnvError", "Env"]

_T = typing.TypeVar("_T")
_StrType = str
_BoolType = bool
_EnumT = typing.TypeVar("_EnumT", bound=Enum)


ErrorMapping = typing.Mapping[str, typing.List[str]]
ErrorList = typing.List[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[typing.Type, typing.Callable[..., _T], ma.fields.Field]
FieldType = typing.Type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable
ParserMethod = typing.Callable[..., _T]


_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^\}:]*)?\}")
Expand Down Expand Up @@ -75,13 +82,96 @@ class ParserConflictError(ValueError):
"""


class Field2MethodType(typing.Generic[_T]):
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: typing.Optional[Subcast[_T]] = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> typing.Optional[_T]:
pass


class Field2MethodListType:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: typing.Optional[Subcast[_T]] = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
delimiter: typing.Optional[str] = None,
) -> typing.Optional[list]:
pass


class Field2MethodDictType:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any],
typing.Iterable[typing.Callable[[typing.Any], typing.Any]],
]
] = None,
required: bool = False,
allow_none: typing.Optional[bool] = None,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
subcast_keys: typing.Optional[Subcast[_T]],
subcast_values: typing.Optional[Subcast[_T]],
delimiter: typing.Optional[str] = None,
) -> typing.Optional[dict]:
pass


class Func2MethodEnum:
def __call__(
self,
value,
type: typing.Type[_EnumT],
default: typing.Optional[_EnumT] = None,
ignore_case: bool = False,
) -> typing.Optional[_EnumT]:
pass


def _field2method(
field_or_factory: FieldOrFactory,
method_name: str,
*,
preprocess: typing.Optional[typing.Callable] = None,
preprocess_kwarg_names: typing.Sequence[str] = tuple(),
) -> ParserMethod:
) -> typing.Any:
def method(
self: "Env",
name: str,
Expand Down Expand Up @@ -152,13 +242,13 @@ def method(
self._errors[parsed_key].extend(error.messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _func2method(func: typing.Callable, method_name: str) -> ParserMethod:
def _func2method(func: typing.Callable[..., _T], method_name: str) -> typing.Any:
def method(
self: "Env",
name: str,
Expand Down Expand Up @@ -200,14 +290,14 @@ def method(
self._errors[parsed_key].extend(messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _make_subcast_field(
subcast: typing.Optional[Subcast],
subcast: typing.Optional[Subcast[_T]],
) -> typing.Type[ma.fields.Field]:
if isinstance(subcast, type) and subcast in ma.Schema.TYPE_MAPPING:
inner_field = ma.Schema.TYPE_MAPPING[subcast]
Expand Down Expand Up @@ -274,9 +364,6 @@ def _preprocess_json(value: typing.Union[str, typing.Mapping, typing.List], **kw
raise ma.ValidationError("Not valid JSON.") from error


_EnumT = typing.TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> _EnumT:
invalid_exc = ma.ValidationError(f"Not a valid '{type.__name__}' enum.")

Expand All @@ -293,7 +380,7 @@ def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) ->
raise invalid_exc


def _dj_db_url_parser(value: str, **kwargs) -> dict:
def _dj_db_url_parser(value: str, **kwargs) -> DBConfig:
try:
import dj_database_url
except ImportError as error:
Expand Down Expand Up @@ -350,7 +437,7 @@ def deserialize(
data: typing.Optional[typing.Mapping] = None,
**kwargs,
) -> ParseResult:
ret = super().deserialize(value, attr, data, **kwargs)
ret = typing.cast(str, super().deserialize(value, attr, data, **kwargs))
return urlparse(ret)


Expand Down Expand Up @@ -398,18 +485,20 @@ class Env:

__call__: ParserMethod = _field2method(ma.fields.Field, "__call__")

int = _field2method(ma.fields.Int, "int")
bool = _field2method(ma.fields.Bool, "bool")
str = _field2method(ma.fields.Str, "str")
float = _field2method(ma.fields.Float, "float")
decimal = _field2method(ma.fields.Decimal, "decimal")
list = _field2method(
int: Field2MethodType["int"] = _field2method(ma.fields.Int, "int")
bool: Field2MethodType["bool"] = _field2method(ma.fields.Bool, "bool")
str: Field2MethodType["str"] = _field2method(ma.fields.Str, "str")
float: Field2MethodType["float"] = _field2method(ma.fields.Float, "float")
decimal: Field2MethodType["decimal.Decimal"] = _field2method(
ma.fields.Decimal, "decimal"
)
list: Field2MethodListType = _field2method(
_make_list_field,
"list",
preprocess=_preprocess_list,
preprocess_kwarg_names=("subcast", "delimiter"),
)
dict = _field2method(
dict: Field2MethodDictType = _field2method(
ma.fields.Dict,
"dict",
preprocess=_preprocess_dict,
Expand All @@ -421,16 +510,26 @@ class Env:
"delimiter",
),
)
json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json)
datetime = _field2method(ma.fields.DateTime, "datetime")
date = _field2method(ma.fields.Date, "date")
time = _field2method(ma.fields.Time, "time")
path = _field2method(PathField, "path")
log_level = _field2method(LogLevelField, "log_level")
timedelta = _field2method(TimeDeltaField, "timedelta")
uuid = _field2method(ma.fields.UUID, "uuid")
url = _field2method(URLField, "url")
enum = _func2method(_enum_parser, "enum")
json: Field2MethodType[typing.Union[typing.List, typing.Dict]] = _field2method(
ma.fields.Field, "json", preprocess=_preprocess_json
)
datetime: Field2MethodType["dt.datetime"] = _field2method(
ma.fields.DateTime, "datetime"
)
date: Field2MethodType["dt.date"] = _field2method(ma.fields.Date, "date")
time: Field2MethodType["dt.time"] = _field2method(ma.fields.Time, "time")
timedelta: Field2MethodType["dt.timedelta"] = _field2method(
TimeDeltaField, "timedelta"
)
path: Field2MethodType[Path] = _field2method(PathField, "path")
log_level: Field2MethodType["builtins.int"] = _field2method(
LogLevelField, "log_level"
)

uuid: Field2MethodType["uuid.UUID"] = _field2method(ma.fields.UUID, "uuid")
url: Field2MethodType[ParseResult] = _field2method(URLField, "url")

enum: Func2MethodEnum = _func2method(_enum_parser, "enum")
dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url")
dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url")
dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url")
Expand All @@ -439,7 +538,9 @@ def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False):
self.eager = eager
self._sealed: bool = False
self.expand_vars = expand_vars
self._fields: typing.Dict[_StrType, typing.Union[ma.fields.Field, type]] = {}
self._fields: typing.Dict[
_StrType, typing.Union[ma.fields.Field, type[ma.fields.Field]]
] = {}
self._values: typing.Dict[_StrType, typing.Any] = {}
self._errors: ErrorMapping = collections.defaultdict(list)
self._prefix: typing.Optional[_StrType] = None
Expand Down
Loading