diff --git a/src/environs/__init__.py b/src/environs/__init__.py index a8e6edc..bf5f5e1 100644 --- a/src/environs/__init__.py +++ b/src/environs/__init__.py @@ -1,5 +1,8 @@ +import builtins import collections import contextlib +import datetime as dt +import decimal import functools import inspect import json as pyjson @@ -7,6 +10,7 @@ import os import re import typing +import uuid from collections.abc import Mapping from datetime import timedelta from enum import Enum @@ -14,6 +18,7 @@ from urllib.parse import ParseResult, urlparse import marshmallow as ma +from dj_database_url import DBConfig from dotenv.main import _walk_to_root, load_dotenv __all__ = ["EnvError", "Env"] @@ -21,6 +26,8 @@ _T = typing.TypeVar("_T") _StrType = str _BoolType = bool +_EnumT = typing.TypeVar("_EnumT", bound=Enum) + ErrorMapping = typing.Mapping[str, typing.List[str]] ErrorList = typing.List[str] @@ -28,7 +35,7 @@ Subcast = typing.Union[typing.Type, typing.Callable[..., _T], ma.fields.Field] FieldType = typing.Type[ma.fields.Field] FieldOrFactory = typing.Union[FieldType, FieldFactory] -ParserMethod = typing.Callable +ParserMethod = typing.Callable[..., _T] _EXPANDED_VAR_PATTERN = re.compile(r"(? typing.Optional[_T]: + pass + + +class Field2MethodListType: + def __call__( + self, + name: str, + default: typing.Any = ma.missing, + subcast: typing.Optional[Subcast[_T]] = None, + *, + # Subset of relevant marshmallow.Field kwargs + load_default: typing.Any = ma.missing, + validate: typing.Optional[ + typing.Union[ + typing.Callable[[typing.Any], typing.Any], + typing.Iterable[typing.Callable[[typing.Any], typing.Any]], + ] + ] = None, + required: bool = False, + allow_none: typing.Optional[bool] = None, + error_messages: typing.Optional[typing.Dict[str, str]] = None, + metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None, + delimiter: typing.Optional[str] = None, + ) -> typing.Optional[list]: + pass + + +class Field2MethodDictType: + def __call__( + self, + name: str, + default: typing.Any = ma.missing, + *, + # Subset of relevant marshmallow.Field kwargs + load_default: typing.Any = ma.missing, + validate: typing.Optional[ + typing.Union[ + typing.Callable[[typing.Any], typing.Any], + typing.Iterable[typing.Callable[[typing.Any], typing.Any]], + ] + ] = None, + required: bool = False, + allow_none: typing.Optional[bool] = None, + error_messages: typing.Optional[typing.Dict[str, str]] = None, + metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None, + subcast_keys: typing.Optional[Subcast[_T]], + subcast_values: typing.Optional[Subcast[_T]], + delimiter: typing.Optional[str] = None, + ) -> typing.Optional[dict]: + pass + + +class Func2MethodEnum: + def __call__( + self, + value, + type: typing.Type[_EnumT], + default: typing.Optional[_EnumT] = None, + ignore_case: bool = False, + ) -> typing.Optional[_EnumT]: + pass + + def _field2method( field_or_factory: FieldOrFactory, method_name: str, *, preprocess: typing.Optional[typing.Callable] = None, preprocess_kwarg_names: typing.Sequence[str] = tuple(), -) -> ParserMethod: +) -> typing.Any: def method( self: "Env", name: str, @@ -152,13 +242,13 @@ def method( self._errors[parsed_key].extend(error.messages) else: self._values[parsed_key] = value - return value + return typing.cast(typing.Optional[_T], value) method.__name__ = method_name return method -def _func2method(func: typing.Callable, method_name: str) -> ParserMethod: +def _func2method(func: typing.Callable[..., _T], method_name: str) -> typing.Any: def method( self: "Env", name: str, @@ -200,14 +290,14 @@ def method( self._errors[parsed_key].extend(messages) else: self._values[parsed_key] = value - return value + return typing.cast(typing.Optional[_T], value) method.__name__ = method_name return method def _make_subcast_field( - subcast: typing.Optional[Subcast], + subcast: typing.Optional[Subcast[_T]], ) -> typing.Type[ma.fields.Field]: if isinstance(subcast, type) and subcast in ma.Schema.TYPE_MAPPING: inner_field = ma.Schema.TYPE_MAPPING[subcast] @@ -274,9 +364,6 @@ def _preprocess_json(value: typing.Union[str, typing.Mapping, typing.List], **kw raise ma.ValidationError("Not valid JSON.") from error -_EnumT = typing.TypeVar("_EnumT", bound=Enum) - - def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> _EnumT: invalid_exc = ma.ValidationError(f"Not a valid '{type.__name__}' enum.") @@ -293,7 +380,7 @@ def _enum_parser(value, type: typing.Type[_EnumT], ignore_case: bool = False) -> raise invalid_exc -def _dj_db_url_parser(value: str, **kwargs) -> dict: +def _dj_db_url_parser(value: str, **kwargs) -> DBConfig: try: import dj_database_url except ImportError as error: @@ -350,7 +437,7 @@ def deserialize( data: typing.Optional[typing.Mapping] = None, **kwargs, ) -> ParseResult: - ret = super().deserialize(value, attr, data, **kwargs) + ret = typing.cast(str, super().deserialize(value, attr, data, **kwargs)) return urlparse(ret) @@ -398,18 +485,20 @@ class Env: __call__: ParserMethod = _field2method(ma.fields.Field, "__call__") - int = _field2method(ma.fields.Int, "int") - bool = _field2method(ma.fields.Bool, "bool") - str = _field2method(ma.fields.Str, "str") - float = _field2method(ma.fields.Float, "float") - decimal = _field2method(ma.fields.Decimal, "decimal") - list = _field2method( + int: Field2MethodType["int"] = _field2method(ma.fields.Int, "int") + bool: Field2MethodType["bool"] = _field2method(ma.fields.Bool, "bool") + str: Field2MethodType["str"] = _field2method(ma.fields.Str, "str") + float: Field2MethodType["float"] = _field2method(ma.fields.Float, "float") + decimal: Field2MethodType["decimal.Decimal"] = _field2method( + ma.fields.Decimal, "decimal" + ) + list: Field2MethodListType = _field2method( _make_list_field, "list", preprocess=_preprocess_list, preprocess_kwarg_names=("subcast", "delimiter"), ) - dict = _field2method( + dict: Field2MethodDictType = _field2method( ma.fields.Dict, "dict", preprocess=_preprocess_dict, @@ -421,16 +510,26 @@ class Env: "delimiter", ), ) - json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json) - datetime = _field2method(ma.fields.DateTime, "datetime") - date = _field2method(ma.fields.Date, "date") - time = _field2method(ma.fields.Time, "time") - path = _field2method(PathField, "path") - log_level = _field2method(LogLevelField, "log_level") - timedelta = _field2method(TimeDeltaField, "timedelta") - uuid = _field2method(ma.fields.UUID, "uuid") - url = _field2method(URLField, "url") - enum = _func2method(_enum_parser, "enum") + json: Field2MethodType[typing.Union[typing.List, typing.Dict]] = _field2method( + ma.fields.Field, "json", preprocess=_preprocess_json + ) + datetime: Field2MethodType["dt.datetime"] = _field2method( + ma.fields.DateTime, "datetime" + ) + date: Field2MethodType["dt.date"] = _field2method(ma.fields.Date, "date") + time: Field2MethodType["dt.time"] = _field2method(ma.fields.Time, "time") + timedelta: Field2MethodType["dt.timedelta"] = _field2method( + TimeDeltaField, "timedelta" + ) + path: Field2MethodType[Path] = _field2method(PathField, "path") + log_level: Field2MethodType["builtins.int"] = _field2method( + LogLevelField, "log_level" + ) + + uuid: Field2MethodType["uuid.UUID"] = _field2method(ma.fields.UUID, "uuid") + url: Field2MethodType[ParseResult] = _field2method(URLField, "url") + + enum: Func2MethodEnum = _func2method(_enum_parser, "enum") dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url") dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url") dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url") @@ -439,7 +538,9 @@ def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False): self.eager = eager self._sealed: bool = False self.expand_vars = expand_vars - self._fields: typing.Dict[_StrType, typing.Union[ma.fields.Field, type]] = {} + self._fields: typing.Dict[ + _StrType, typing.Union[ma.fields.Field, type[ma.fields.Field]] + ] = {} self._values: typing.Dict[_StrType, typing.Any] = {} self._errors: ErrorMapping = collections.defaultdict(list) self._prefix: typing.Optional[_StrType] = None