diff --git a/gen3discoveryai/main.py b/gen3discoveryai/main.py index af064aa1..a839b1d8 100644 --- a/gen3discoveryai/main.py +++ b/gen3discoveryai/main.py @@ -1,7 +1,7 @@ import os +import traceback from contextlib import asynccontextmanager from importlib.metadata import version -import traceback import fastapi import yaml @@ -105,19 +105,14 @@ async def lifespan(fastapi_app: fastapi.FastAPI): } config.topics[topic].update(topic_raw_cfg["metadata"]) - chain_instance = chain_factory.get( - topic_raw_cfg["topic_chain"], + _create_and_register_topic_chain( topic=topic, - metadata=config.topics[topic], - ) - config.topics[topic].update( - { - "topic_chain": chain_instance, - } + topic_raw_cfg=topic_raw_cfg, + chain_factory=chain_factory, ) + logging.info(f"Added topic `{topic}`") logging.debug(f"`{topic}` configuration: `{topic_raw_cfg}`") - except Exception as exc: logging.error( f"Unable to load `{topic}` configuration with: {topic_raw_cfg}. " @@ -135,4 +130,20 @@ async def lifespan(fastapi_app: fastapi.FastAPI): config.topics.clear() +def _create_and_register_topic_chain(topic, topic_raw_cfg, chain_factory): + """ + Small helper function to create instance of the topic chain and add to the config + """ + chain_instance = chain_factory.get( + topic_raw_cfg["topic_chain"], + topic=topic, + metadata=config.topics[topic], + ) + config.topics[topic].update( + { + "topic_chain": chain_instance, + } + ) + + app = get_app() diff --git a/tests/test_config.py b/tests/test_config.py index 56bb1af5..a8f8e2b9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,11 @@ import importlib import os +from unittest.mock import patch import pytest -from gen3discoveryai import config -from gen3discoveryai.main import _override_generated_openapi_spec +from gen3discoveryai import config, main +from gen3discoveryai.main import _override_generated_openapi_spec, lifespan from gen3discoveryai.topic_chains.utils import get_from_cfg_metadata @@ -21,6 +22,38 @@ def test_bad_config_metadata(): os.chdir(os.path.dirname(os.path.abspath(__file__)).rstrip("/") + "/..") +@pytest.mark.asyncio +@patch("gen3discoveryai.main._create_and_register_topic_chain") +async def test_bad_config_default_topic(create_and_register_topic_chain): + """ + Test when config loading raises an error, either it's reraised if it's the default topic + """ + create_and_register_topic_chain.side_effect = Exception("some expection") + + with pytest.raises(Exception): + async with lifespan(main.app): + pass + + +@pytest.mark.asyncio +@patch("gen3discoveryai.main._create_and_register_topic_chain") +async def test_bad_config_non_default_topic(create_and_register_topic_chain): + """ + Test when config loading raises an error, either it's reraised if it's the default topic + """ + + def _exception_if_default(*args, **kwargs): + # simulate default topic being okay, e.g. don't raise error here + if kwargs.get("topic") == "default": + pass + + create_and_register_topic_chain.side_effect = _exception_if_default + + async with lifespan(main.app): + # we don't expect an exception for non default topics + assert True + + def test_metadata_cfg_util(): """ If it exists, return it