Skip to content

Commit

Permalink
Implement Strategy pattern for approach selection and use dynamic approach discovery (#385)
Browse files Browse the repository at this point in the history

Co-authored-by: = Enea_Gore <[email protected]>
Co-authored-by: Felix T.J. Dietrich <[email protected]>
  • Loading branch information
3 people authored Jan 20, 2025
1 parent 863ba88 commit 23ce7e9
Show file tree
Hide file tree
Showing 12 changed files with 700 additions and 428 deletions.
33 changes: 33 additions & 0 deletions athena/athena/approach_discovery/discover_approaches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib
import pkgutil
import inspect

def discover_approach_configs(base_package, base_class=None):
    """
    Discover and return classes found in the modules of a package.

    Recursively walks every submodule of *base_package* and collects the
    classes visible at module level.

    Args:
        base_package (str): Dotted name of the package to search.
        base_class (type, optional): When given, only strict subclasses of
            this class are collected (the base class itself is excluded).

    Returns:
        dict: A dictionary mapping class names to class objects. If two
            modules expose classes with the same name, the one imported
            last wins.
    """
    classes = {}

    def recursive_import(package_name):
        # Import the (sub)package so we can walk its __path__ for submodules.
        package = importlib.import_module(package_name)
        for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__):
            full_module_name = f"{package_name}.{module_name}"
            if is_pkg:
                recursive_import(full_module_name)
            else:
                module = importlib.import_module(full_module_name)
                # NOTE: inspect.getmembers also yields classes *imported into*
                # the module, not only classes defined in it.
                for name, obj in inspect.getmembers(module, inspect.isclass):
                    if base_class is None or (issubclass(obj, base_class) and obj is not base_class):
                        classes[name] = obj

    recursive_import(base_package)
    return classes
74 changes: 74 additions & 0 deletions athena/athena/approach_discovery/strategy_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from athena.approach_discovery.discover_approaches import discover_approach_configs
from typing import List, Callable

class SuggestionStrategyFactory:
    """
    A factory for discovering, caching, and instantiating suggestion strategies.

    Strategy classes are discovered dynamically from a base package and keyed
    by their class name. `get_strategy` then returns an instance of the
    strategy whose class name matches the type name of the given configuration
    object.

    Attributes:
        _strategies (dict): Class-level cache mapping configuration class names
            to strategy classes. Shared by all factory instances, so the first
            factory created in a process determines its contents.
    """
    _strategies: dict[str, Callable] = {}

    def __init__(self, base_package: str, base_class: type):
        """
        Create a factory and eagerly populate the strategy cache.

        Args:
            base_package (str): The base package to search for strategies and
                configurations.
            base_class (type): Base class used to filter discovered classes.
        """
        self.base_package = base_package
        self.base_class = base_class

        # Initialize strategies on object creation.
        self.initialize_strategies()

    def initialize_strategies(self):
        """
        Populate the class-level strategy cache if it is still empty.

        Discovery runs only once per process: subsequent factories reuse the
        cached mapping even if created with a different package or base class.
        """
        if not SuggestionStrategyFactory._strategies:
            configs = discover_approach_configs(self.base_package, base_class=self.base_class)
            # The discovered configuration classes serve directly as the
            # strategies, keyed by their class name.
            SuggestionStrategyFactory._strategies.update(configs)

    def get_strategy(self, config):
        """
        Retrieve an instance of the strategy registered for *config*'s type.

        Args:
            config (object): The configuration object for which the strategy
                is required; its class name selects the strategy.

        Returns:
            object: A new instance of the strategy class associated with the
                given configuration.

        Raises:
            ValueError: If no strategy is found for the given configuration type.
        """
        config_type = type(config).__name__
        strategy_class = SuggestionStrategyFactory._strategies.get(config_type)
        if strategy_class is None:
            raise ValueError(f"No strategy found for config type: {config_type}")
        return strategy_class()
4 changes: 4 additions & 0 deletions modules/text/module_text_llm/module_text_llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import dotenv
from athena.approach_discovery.strategy_factory import SuggestionStrategyFactory

dotenv.load_dotenv(override=True)

def get_strategy_factory(base_class):
    """Return a SuggestionStrategyFactory that discovers approaches in this module's package."""
    package_name = "module_text_llm"
    return SuggestionStrategyFactory(package_name, base_class)
4 changes: 2 additions & 2 deletions modules/text/module_text_llm/module_text_llm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def process_incoming_feedback(exercise: Exercise, submission: Submission, feedba

@feedback_provider
async def suggest_feedback(exercise: Exercise, submission: Submission, is_graded: bool, module_config: Configuration) -> List[Feedback]:
logger.info("suggest_feedback: %s suggestions for submission %d of exercise %d were requested",
"Graded" if is_graded else "Non-graded", submission.id, exercise.id)
logger.info("suggest_feedback: %s suggestions for submission %d of exercise %d were requested, with approach: %s",
"Graded" if is_graded else "Non-graded", submission.id, exercise.id, module_config.approach.__class__.__name__)
return await generate_suggestions(exercise, submission, module_config.approach, module_config.debug, is_graded)


Expand Down
15 changes: 8 additions & 7 deletions modules/text/module_text_llm/module_text_llm/approach_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from abc import ABC
from pydantic import BaseModel, Field
from llm_core.models import ModelConfigType, DefaultModelConfig
from enum import Enum

class ApproachType(str, Enum):
basic = "BasicApproach"
chain_of_thought = "ChainOfThought"
from abc import ABC, abstractmethod
from athena.text import Exercise, Submission

class ApproachConfig(BaseModel, ABC):
max_input_tokens: int = Field(default=3000, description="Maximum number of tokens in the input prompt.")
model: ModelConfigType = Field(default=DefaultModelConfig())
type: str = Field(..., description="The type of approach config")

@abstractmethod
async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool):
pass

class Config:
use_enum_values = True
use_enum_values = True

Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from typing import List
from athena.text import Exercise, Submission, Feedback
from module_text_llm.basic_approach import BasicApproachConfig
from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig
from module_text_llm.approach_config import ApproachConfig

from module_text_llm.basic_approach.generate_suggestions import generate_suggestions as generate_suggestions_basic
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions as generate_cot_suggestions
from module_text_llm import get_strategy_factory

async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool, is_graded: bool) -> List[Feedback]:
if isinstance(config, BasicApproachConfig):
return await generate_suggestions_basic(exercise, submission, config, debug, is_graded)
if isinstance(config, ChainOfThoughtConfig):
return await generate_cot_suggestions(exercise, submission, config, debug, is_graded)
raise ValueError("Unsupported config type provided.")
strategy_factory = get_strategy_factory(ApproachConfig)
strategy = strategy_factory.get_strategy(config)
return await strategy.generate_suggestions(exercise, submission, config, debug = debug, is_graded = is_graded)
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from module_text_llm.approach_config import ApproachConfig
from pydantic import Field
from typing import Literal


from module_text_llm.basic_approach.generate_suggestions import generate_suggestions
from athena.text import Exercise, Submission
from module_text_llm.basic_approach.prompt_generate_suggestions import GenerateSuggestionsPrompt

class BasicApproachConfig(ApproachConfig):
type: Literal['basic'] = 'basic'
generate_suggestions_prompt: GenerateSuggestionsPrompt = Field(default=GenerateSuggestionsPrompt())

async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool):
return await generate_suggestions(exercise,submission,config,debug,is_graded)

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from module_text_llm.approach_config import ApproachConfig

from athena import emit_meta
from athena.text import Exercise, Submission, Feedback
Expand All @@ -10,11 +11,11 @@
)
from llm_core.utils.predict_and_parse import predict_and_parse

from module_text_llm.config import BasicApproachConfig
# from module_text_llm.config import BasicApproachConfig
from module_text_llm.helpers.utils import add_sentence_numbers, get_index_range_from_line_range, format_grading_instructions
from module_text_llm.basic_approach.prompt_generate_suggestions import AssessmentModel

async def generate_suggestions(exercise: Exercise, submission: Submission, config: BasicApproachConfig, debug: bool, is_graded: bool) -> List[Feedback]:
async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool, is_graded: bool) -> List[Feedback]:
model = config.model.get_model() # type: ignore[attr-defined]
prompt_input = {
"max_points": exercise.max_points,
Expand Down Expand Up @@ -89,8 +90,8 @@ async def generate_suggestions(exercise: Exercise, submission: Submission, confi
index_start=index_start,
index_end=index_end,
credits=feedback.credits,
structured_grading_instruction_id=grading_instruction_id,
is_graded=is_graded,
structured_grading_instruction_id=grading_instruction_id,
meta={}
))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from pydantic import Field
from typing import Literal
from athena.text import Exercise, Submission

from module_text_llm.approach_config import ApproachConfig
from module_text_llm.chain_of_thought_approach.prompt_generate_feedback import CoTGenerateSuggestionsPrompt
from module_text_llm.chain_of_thought_approach.prompt_thinking import ThinkingPrompt
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions

class ChainOfThoughtConfig(ApproachConfig):
type: Literal['chain_of_thought'] = 'chain_of_thought'
thinking_prompt: ThinkingPrompt = Field(default=ThinkingPrompt())
generate_suggestions_prompt: CoTGenerateSuggestionsPrompt = Field(default=CoTGenerateSuggestionsPrompt())

async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool):
return await generate_suggestions(exercise,submission,config,debug,is_graded)

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from athena.text import Exercise, Submission, Feedback
from athena.logger import logger

from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig
from module_text_llm.approach_config import ApproachConfig

from llm_core.utils.llm_utils import (
get_chat_prompt_with_formatting_instructions,
Expand All @@ -17,7 +17,7 @@
from module_text_llm.chain_of_thought_approach.prompt_generate_feedback import AssessmentModel


async def generate_suggestions(exercise: Exercise, submission: Submission, config: ChainOfThoughtConfig, debug: bool, is_graded: bool) -> List[Feedback]:
async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool, is_graded: bool) -> List[Feedback]:
model = config.model.get_model() # type: ignore[attr-defined]

prompt_input = {
Expand Down Expand Up @@ -118,8 +118,8 @@ async def generate_suggestions(exercise: Exercise, submission: Submission, confi
index_start=index_start,
index_end=index_end,
credits=feedback.credits,
structured_grading_instruction_id=grading_instruction_id,
is_graded=is_graded,
structured_grading_instruction_id=grading_instruction_id,
meta={}
))

Expand Down
Loading

0 comments on commit 23ce7e9

Please sign in to comment.