Skip to content

Commit

Permalink
Change how defaults are processed
Browse files Browse the repository at this point in the history
  • Loading branch information
srfoster65 committed Nov 1, 2023
1 parent da1ad70 commit 2b90de5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
12 changes: 9 additions & 3 deletions src/arg_init/_arg_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ArgInit(ABC):

def __init__(
self,
env_prefix: str = "",
priority: bool = ENV_PRIORITY,
env_prefix: str | None = None,
):
self._env_prefix = env_prefix
self._priority = priority
Expand Down Expand Up @@ -73,21 +73,27 @@ def _get_kwargs(self, arginfo, use_kwargs) -> dict:

def _make_args(self, arguments, defaults) -> None:
for name, value in arguments.items():
arg_defaults = defaults.get(name)
arg_defaults = self._get_arg_defaults(name, defaults)
env_name = self._get_env_name(name, arg_defaults)
default_value = self._get_default_value(arg_defaults)
values = Values(
arg=value, env=self._get_env_value(env_name), default=default_value
)
self._args[name] = Arg(name, env_name, values).resolve(self._priority)

def _get_arg_defaults(self, name, defaults):
for arg_defaults in defaults:
if arg_defaults.name == name:
return arg_defaults
return None

def _get_env_name(self, name, arg_defaults):
"""Determine the name to use for the env."""
if arg_defaults:
if arg_defaults.disable_env:
return None
if arg_defaults.env_name:
return arg_defaults.env_name
return arg_defaults.env_name.upper()
env_parts = [item for item in (self._env_prefix, name) if item]
return "_".join(env_parts).upper()

Expand Down
22 changes: 12 additions & 10 deletions src/arg_init/_class_arg_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import stack, getargvalues
import logging

from ._arg_init import ArgInit
from ._arg_init import ArgInit, ENV_PRIORITY


logger = logging.getLogger(__name__)
Expand All @@ -20,22 +20,24 @@ class ClassArgInit(ArgInit):

def __init__(
self,
priority=ENV_PRIORITY,
env_prefix=None,
use_kwargs=False,
set_attrs=True,
protect_attrs=True,
defaults=None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(priority, env_prefix, **kwargs)
if defaults is None:
defaults = {}
defaults = []

self._set_attrs = set_attrs
self._protect_attrs = protect_attrs
calling_stack = stack()[self.STACK_LEVEL_OFFSET]
self._init_args(calling_stack, use_kwargs, defaults)
self._class_ref = self._get_class_instance(calling_stack.frame)
self._set_class_attrs()
class_instance = self._get_class_instance(calling_stack.frame)
self._set_class_attrs(class_instance)

def _get_arguments(self, frame, use_kwargs):
"""
Expand All @@ -53,24 +55,24 @@ def _get_arguments(self, frame, use_kwargs):
args.update(self._get_kwargs(arginfo, use_kwargs))
return args

def _set_class_attrs(self):
def _set_class_attrs(self, class_ref):
"""Set attributes for the class object."""
if self._set_attrs:
logger.debug("Setting class attributes")
for arg in self._args.values():
self._set_attr(arg.name, arg.value)
self._set_attr(class_ref, arg.name, arg.value)

def _get_attr_name(self, name):
if self._protect_attrs:
return name if name.startswith("_") else "_" + name
return name

def _set_attr(self, name, value):
def _set_attr(self, class_instance, name, value):
name = self._get_attr_name(name)
if hasattr(self._class_ref, name):
if hasattr(class_instance, name):
raise AttributeError(f"Attribute already exists: {name}")
logger.debug(" %s = %s", name, value)
setattr(self._class_ref, name, value)
setattr(class_instance, name, value)

def _get_class_instance(self, frame):
"""
Expand Down
15 changes: 11 additions & 4 deletions src/arg_init/_function_arg_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from inspect import stack, getargvalues
import logging

from ._arg_init import ArgInit
from ._arg_init import ArgInit, ENV_PRIORITY


logger = logging.getLogger(__name__)
Expand All @@ -17,10 +17,17 @@ class FunctionArgInit(ArgInit):
Initialises arguments from a function.
"""

def __init__(self, use_kwargs=False, defaults=None, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
priority=ENV_PRIORITY,
env_prefix=None,
use_kwargs=False,
defaults=None,
**kwargs,
):
super().__init__(priority, env_prefix, **kwargs)
if defaults is None:
defaults = {}
defaults = []
calling_stack = stack()[self.STACK_LEVEL_OFFSET]
self._init_args(calling_stack, use_kwargs, defaults)

Expand Down

0 comments on commit 2b90de5

Please sign in to comment.