-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement strategy Pattern for approach selection and use dynamic app…
…roach discovery (#385) Co-authored-by: = Enea_Gore <[email protected]> Co-authored-by: Felix T.J. Dietrich <[email protected]>
- Loading branch information
1 parent
863ba88
commit 23ce7e9
Showing
12 changed files
with
700 additions
and
428 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import importlib | ||
import pkgutil | ||
import inspect | ||
|
||
def discover_approach_configs(base_package, base_class=None): | ||
""" | ||
Discover and return classes within the specified package. | ||
Args: | ||
base_package (str): The package to search. | ||
base_class (type, optional): A base class to filter discovered classes. | ||
Only subclasses of this base class will be included. | ||
Returns: | ||
dict: A dictionary mapping class names to class objects. | ||
""" | ||
classes = {} | ||
package = importlib.import_module(base_package) | ||
|
||
def recursive_import(package_name): | ||
package = importlib.import_module(package_name) | ||
for _, module_name, is_pkg in pkgutil.iter_modules(package.__path__): | ||
full_module_name = f"{package_name}.{module_name}" | ||
if is_pkg: | ||
recursive_import(full_module_name) | ||
else: | ||
module = importlib.import_module(full_module_name) | ||
for name, obj in inspect.getmembers(module, inspect.isclass): | ||
if base_class is None or (issubclass(obj, base_class) and obj is not base_class): | ||
classes[name] = obj | ||
|
||
recursive_import(base_package) | ||
return classes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from athena.approach_discovery.discover_approaches import discover_approach_configs | ||
from typing import List, Callable | ||
|
||
class SuggestionStrategyFactory: | ||
""" | ||
A factory class for discovering, initializing, and retrieving suggestion strategies. | ||
The `SuggestionStrategyFactory` dynamically loads strategy classes, associates them with | ||
specific configuration types, and provides instances of the strategies based on the given | ||
configuration. It supports modular discovery and initialization to handle a variety of | ||
strategies seamlessly. | ||
Attributes: | ||
_strategies (dict): A dictionary mapping configuration class names to their corresponding strategy classes. | ||
""" | ||
_strategies: dict[str, Callable] = {} | ||
def __init__(self, base_package: str, base_class: type): | ||
""" | ||
Initialize the factory by providing the base package and base class for discovering strategies. | ||
Args: | ||
base_package (str): The base package to search for strategies and configurations. | ||
base_class (type): The base class for configurations that strategies will be associated with. | ||
""" | ||
self.base_package = base_package | ||
self.base_class = base_class | ||
|
||
# Initialize strategies on object creation | ||
self.initialize_strategies() | ||
|
||
def initialize_strategies(self): | ||
""" | ||
Initialize the factory by associating configuration types with their corresponding strategies. | ||
This method uses the `discover_classes` function to identify configuration classes and their | ||
associated strategies. The mappings are stored in the `_strategies` dictionary for later retrieval. | ||
Args: | ||
base_package (str): The base package to search for strategies and configurations. | ||
Defaults to "module_text_llm". | ||
""" | ||
if not SuggestionStrategyFactory._strategies: | ||
configs = discover_approach_configs(self.base_package, base_class=self.base_class) | ||
# strategies = discover_approach_configs(base_package) | ||
|
||
for config_name, config_class in configs.items(): | ||
strategy_class = configs.get(config_name) | ||
if strategy_class: | ||
SuggestionStrategyFactory._strategies[config_name] = strategy_class | ||
|
||
|
||
def get_strategy(self, config): | ||
""" | ||
Retrieve an instance of the strategy corresponding to the given configuration. | ||
If the strategies have not been initialized, this method will initialize them first. | ||
The method then matches the type of the provided configuration with the corresponding | ||
strategy class and returns an instance of it. | ||
Args: | ||
config (object): The configuration object for which the strategy is required. | ||
Returns: | ||
object: An instance of the strategy class associated with the given configuration. | ||
Raises: | ||
ValueError: If no strategy is found for the given configuration type. | ||
""" | ||
|
||
config_type = type(config).__name__ | ||
strategy_class = SuggestionStrategyFactory._strategies.get(config_type) | ||
if not strategy_class: | ||
raise ValueError(f"No strategy found for config type: {config_type}") | ||
return strategy_class() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
import dotenv | ||
from athena.approach_discovery.strategy_factory import SuggestionStrategyFactory | ||
|
||
dotenv.load_dotenv(override=True) | ||
|
||
def get_strategy_factory(base_class): | ||
return SuggestionStrategyFactory("module_text_llm", base_class) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 8 additions & 7 deletions
15
modules/text/module_text_llm/module_text_llm/approach_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,17 @@ | ||
from abc import ABC | ||
from pydantic import BaseModel, Field | ||
from llm_core.models import ModelConfigType, DefaultModelConfig | ||
from enum import Enum | ||
|
||
class ApproachType(str, Enum): | ||
basic = "BasicApproach" | ||
chain_of_thought = "ChainOfThought" | ||
from abc import ABC, abstractmethod | ||
from athena.text import Exercise, Submission | ||
|
||
class ApproachConfig(BaseModel, ABC): | ||
max_input_tokens: int = Field(default=3000, description="Maximum number of tokens in the input prompt.") | ||
model: ModelConfigType = Field(default=DefaultModelConfig()) | ||
type: str = Field(..., description="The type of approach config") | ||
|
||
@abstractmethod | ||
async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool): | ||
pass | ||
|
||
class Config: | ||
use_enum_values = True | ||
use_enum_values = True | ||
|
14 changes: 4 additions & 10 deletions
14
modules/text/module_text_llm/module_text_llm/approach_controller.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,9 @@ | ||
from typing import List | ||
from athena.text import Exercise, Submission, Feedback | ||
from module_text_llm.basic_approach import BasicApproachConfig | ||
from module_text_llm.chain_of_thought_approach import ChainOfThoughtConfig | ||
from module_text_llm.approach_config import ApproachConfig | ||
|
||
from module_text_llm.basic_approach.generate_suggestions import generate_suggestions as generate_suggestions_basic | ||
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions as generate_cot_suggestions | ||
from module_text_llm import get_strategy_factory | ||
|
||
async def generate_suggestions(exercise: Exercise, submission: Submission, config: ApproachConfig, debug: bool, is_graded: bool) -> List[Feedback]: | ||
if isinstance(config, BasicApproachConfig): | ||
return await generate_suggestions_basic(exercise, submission, config, debug, is_graded) | ||
if isinstance(config, ChainOfThoughtConfig): | ||
return await generate_cot_suggestions(exercise, submission, config, debug, is_graded) | ||
raise ValueError("Unsupported config type provided.") | ||
strategy_factory = get_strategy_factory(ApproachConfig) | ||
strategy = strategy_factory.get_strategy(config) | ||
return await strategy.generate_suggestions(exercise, submission, config, debug = debug, is_graded = is_graded) |
7 changes: 5 additions & 2 deletions
7
modules/text/module_text_llm/module_text_llm/basic_approach/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
from module_text_llm.approach_config import ApproachConfig | ||
from pydantic import Field | ||
from typing import Literal | ||
|
||
|
||
from module_text_llm.basic_approach.generate_suggestions import generate_suggestions | ||
from athena.text import Exercise, Submission | ||
from module_text_llm.basic_approach.prompt_generate_suggestions import GenerateSuggestionsPrompt | ||
|
||
class BasicApproachConfig(ApproachConfig): | ||
type: Literal['basic'] = 'basic' | ||
generate_suggestions_prompt: GenerateSuggestionsPrompt = Field(default=GenerateSuggestionsPrompt()) | ||
|
||
async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool): | ||
return await generate_suggestions(exercise,submission,config,debug,is_graded) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
modules/text/module_text_llm/module_text_llm/chain_of_thought_approach/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
from pydantic import Field | ||
from typing import Literal | ||
from athena.text import Exercise, Submission | ||
|
||
from module_text_llm.approach_config import ApproachConfig | ||
from module_text_llm.chain_of_thought_approach.prompt_generate_feedback import CoTGenerateSuggestionsPrompt | ||
from module_text_llm.chain_of_thought_approach.prompt_thinking import ThinkingPrompt | ||
from module_text_llm.chain_of_thought_approach.generate_suggestions import generate_suggestions | ||
|
||
class ChainOfThoughtConfig(ApproachConfig): | ||
type: Literal['chain_of_thought'] = 'chain_of_thought' | ||
thinking_prompt: ThinkingPrompt = Field(default=ThinkingPrompt()) | ||
generate_suggestions_prompt: CoTGenerateSuggestionsPrompt = Field(default=CoTGenerateSuggestionsPrompt()) | ||
|
||
async def generate_suggestions(self, exercise: Exercise, submission: Submission, config, *, debug: bool, is_graded: bool): | ||
return await generate_suggestions(exercise,submission,config,debug,is_graded) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.