diff --git a/jeni.py b/jeni.py index f0fb897..d92a0a1 100644 --- a/jeni.py +++ b/jeni.py @@ -70,46 +70,66 @@ def close(self): """ +class FactoryProvider(Provider): + """Adapt factory functions to the Provider interface. + + `Injector` uses this class to support registering factories. + """ + unset_error = None + + @classmethod + def bind(cls, fn): + @annotate(annotate.partial_regardless(fn)) + def init(fn): + return cls(fn) + return init + + def __init__(self, function): + self.function = function + try: + self.value = function() + except UnsetError as err: + self.unset_error = err + + def get(self, name=None): + if name is not None: + return self.function(name) + if self.unset_error is not None: + raise self.unset_error + return self.value + + class GeneratorProvider(Provider): """Manage generator lifecycle to implement Provider interface. `Injector` uses this class to support registering generators. - When used directly, note that method `init` must be called before `get`:: - - def generator(foo, bar): - yield - # continues when GeneratorProvider.close is called. - provider = GeneratorProvider(generator) - provider.init('foo', 'bar') - provider.get() """ + @classmethod + def bind(cls, fn, support_name=False): + @annotate(annotate.partial_regardless(fn)) + def init(fn): + return cls(fn, support_name=support_name) + return init + def __init__(self, function, support_name=False): """Accept generator function & whether generator supports send.""" - if not inspect.isgeneratorfunction(function): - msg = '{!r} is not a generator function' - raise TypeError(msg.format(function)) self.function = function self.support_name = support_name - self.initialized = False - def init(self, *a, **kw): - """Call function to create generator, passing arguments provided.""" - self.generator = self.function(*a, **kw) + self.generator = function() + if not inspect.isgenerator(self.generator): + msg = '{!r} is not a generator function' + raise TypeError(msg.format(function)) + try: self.init_value = next(self.generator) except StopIteration: msg = "generator didn't yield: function {!r}" raise RuntimeError(msg.format(self.function)) - else: - self.initialized = True - return self.init_value def get(self, name=None): """Get initial yield value, or result of send(name) if name given.""" - if not self.initialized: - msg = '{!r} not initialized; call `init` before `get`.' - raise RuntimeError(msg.format(self)) if name is None: return self.init_value elif not self.support_name: @@ -124,8 +144,6 @@ def get(self, name=None): def close(self): """Close the generator.""" - if not self.initialized: - raise RuntimeError('{!r} not initialized'.format(self)) if self.support_name: self.generator.close() try: @@ -343,6 +361,7 @@ def eager_partial_regardless(__fn, *a, **kw): class Injector(object): """Collects dependencies and reads annotations to inject them.""" annotator_class = Annotator + factory_provider = FactoryProvider generator_provider = GeneratorProvider re_note = re.compile(r'^(.*?)(?::(.*))?$') # annotation is 'object:name' @@ -430,18 +449,16 @@ def spam(): Injector.provider('hello', HelloProvider) """ - def decorator(fn_or_class): - if inspect.isgeneratorfunction(fn_or_class): - fn = fn_or_class - fn.support_name = name - cls.register(note, fn) - else: - provider = fn_or_class - if not hasattr(provider, 'get'): - msg = "{!r} does not meet provider interface with 'get'" - raise ValueError(msg.format(provider)) - cls.register(note, provider) - return fn_or_class + def decorator(provider): + if inspect.isgeneratorfunction(provider): + # Automatically adapt generator functions + provider = cls.generator_provider.bind( + provider, support_name=name) + return decorator(provider) + + cls.register(note, provider) + return provider + if provider is not None: decorator(provider) else: @@ -467,12 +484,14 @@ def echo(name=None): Injector.factory('echo', echo) """ + def decorator(f): + provider = cls.factory_provider.bind(f) + cls.register(note, provider) + return f + if fn is not None: - cls.register(note, fn) + decorator(fn) else: - def decorator(f): - cls.register(note, f) - return f return decorator @classmethod @@ -620,9 +639,6 @@ def close(self): if self.closed: raise RuntimeError('{!r} already closed'.format(self)) for basenote in reversed(self.get_order): - if basenote not in self.instances: - # Provider is not an instance; no close implementation. - continue # Note: Unable to apply injector on close method. self.instances[basenote].close() self.closed = True @@ -676,36 +692,26 @@ def handle_provider(self, provider_or_fn, note): self.get_order.append(basenote) return result - def _handle_provider(self, provider_or_fn, note, basenote, name): - if basenote in self.instances: - provider_or_fn = self.instances[basenote] - elif inspect.isclass(provider_or_fn): - # Inject class __init__, if annotated. - cls = provider_or_fn - if hasattr(cls, '__init__') and self.has_annotations(cls.__init__): - args, kwargs = self.prepare_callable(cls.__init__) - provider_or_fn = provider_or_fn(*args, **kwargs) + def _handle_provider(self, provider_factory, note, basenote, name): + if basenote not in self.instances: + if (isinstance(provider_factory, type) and + self.has_annotations(provider_factory.__init__)): + args, kwargs = self.prepare_callable(provider_factory.__init__) + self.instances[basenote] = provider_factory(*args, **kwargs) + else: - provider_or_fn = provider_or_fn() - self.instances[basenote] = provider_or_fn - elif inspect.isgeneratorfunction(provider_or_fn): - provider_or_fn, value = self.init_generator(provider_or_fn) - self.instances[basenote] = provider_or_fn - self.values[basenote] = value - if name is None: - return value - if hasattr(provider_or_fn, 'get'): - fn = provider_or_fn.get - else: - fn = provider_or_fn - if self.has_annotations(fn): - fn = self.partial(fn) + self.instances[basenote] = self.apply_regardless( + provider_factory) + + provider = self.instances[basenote] + get = self.partial_regardless(provider.get) + try: - if name is None: - value = fn() - self.values[basenote] = value - return value - return fn(name=name) + if name is not None: + return get(name=name) + self.values[basenote] = get() + return self.values[basenote] + except UnsetError: # Use sys.exc_info to support both Python 2 and Python 3. exc_type, exc_value, tb = sys.exc_info() @@ -737,17 +743,6 @@ def lookup(cls, basenote): return c.provider_registry[basenote] raise LookupError(repr(basenote)) - def init_generator(self, fn): - """Implementation to initialize generator providers.""" - provider = self.generator_provider(fn, support_name=fn.support_name) - if self.has_annotations(provider.function): - notes, keyword_notes = self.get_annotations(provider.function) - args, kwargs = self.prepare_notes(*notes, **keyword_notes) - value = provider.init(*args, **kwargs) - else: - value = provider.init() - return provider, value - def __enter__(self): """Support for context manager, returning self.""" return self diff --git a/test_jeni.py b/test_jeni.py index 8f1d6f1..69299ae 100644 --- a/test_jeni.py +++ b/test_jeni.py @@ -950,6 +950,7 @@ def test_a_few_calls(self): 'hello': 1, 'hello:thing': 1, 'eggs': 2, + (jeni.PARTIAL_REGARDLESS, (eggs, (), ())): 1, } self.injector.get('eggs') self.injector.get('hello') @@ -962,6 +963,7 @@ def test_many_calls(self): 'hello': 10, 'hello:thing': 15, 'eggs': 21, + (jeni.PARTIAL_REGARDLESS, (eggs, (), ())): 1, } for _ in range(10): self.injector.get('hello') @@ -992,7 +994,6 @@ def test_generator(self): def fn(): yield 42 provider = jeni.GeneratorProvider(fn) - provider.init() self.assertEqual(42, provider.get()) self.assertEqual(42, provider.get()) self.assertRaises(TypeError, provider.get, name='name') @@ -1002,17 +1003,10 @@ def fn(): "not a generator" self.assertRaises(TypeError, jeni.GeneratorProvider, fn) - def test_init_error(self): - def fn(): yield - provider = jeni.GeneratorProvider(fn) - self.assertRaises(RuntimeError, provider.get) - self.assertRaises(RuntimeError, provider.close) - def test_init_no_error(self): def fn(): yield 42 provider = jeni.GeneratorProvider(fn) - provider.init() provider.get() provider.close() @@ -1021,14 +1015,12 @@ def fn(work=False): if work: yield 'foo' self.assertEqual(['foo'], list(fn(work=True))) - provider = jeni.GeneratorProvider(fn) - self.assertRaises(RuntimeError, provider.init) + self.assertRaises(RuntimeError, jeni.GeneratorProvider, fn) def test_generator_with_broken_name_support(self): def fn(): yield 42 provider = jeni.GeneratorProvider(fn, support_name=True) - provider.init() self.assertEqual(42, provider.get()) self.assertRaises(RuntimeError, provider.get, name='name') @@ -1036,7 +1028,6 @@ def test_generator_which_keeps_yielding(self): def fn(): yield 'one'; yield 'two' provider = jeni.GeneratorProvider(fn) - provider.init() self.assertRaises(RuntimeError, provider.close) @@ -1104,19 +1095,11 @@ def setUp(self): self.Injector = self.TestInjector self.injector = self.Injector() - def decorate(self): - @self.Injector.provider('no get') - class BadProvider(object): - pass - def subclass(self): class BadSubclass(jeni.Provider): pass return BadSubclass - def test_interface_check(self): - self.assertRaises(ValueError, self.decorate) - def test_subclass_meta(self): cls = self.subclass() self.assertRaises(TypeError, cls)