diff --git a/kai/kai_config.py b/kai/kai_config.py
index 04019392..3707931b 100644
--- a/kai/kai_config.py
+++ b/kai/kai_config.py
@@ -166,6 +166,7 @@ class KaiConfigModels(BaseModel):
     llama_header: Optional[bool] = Field(default=None)
     llm_retries: int = 5
     llm_retry_delay: float = 10.0
+    llm_call_budget: int = -1


 # Main config
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index 11348292..55f556a1 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -22,7 +22,17 @@
 LOG = get_logger(__name__)


+class LLMCallBudgetReached(Exception):
+    def __init__(
+        self, message: str = "The defined LLM call budget has been reached"
+    ) -> None:
+        super().__init__(message)
+
+
 class ModelProvider:
+
+    llm_call_budget: int = -1
+
     def __init__(
         self,
         config: KaiConfigModels,
@@ -33,6 +43,7 @@ def __init__(
         self.llm_retry_delay: float = config.llm_retry_delay
         self.demo_mode: bool = demo_mode
         self.cache = cache
+        self.llm_call_budget = config.llm_call_budget

         model_class: type[BaseChatModel]
         defaults: dict[str, Any]
@@ -193,6 +204,8 @@ def invoke(
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
+        if self.llm_call_budget == 0:
+            raise LLMCallBudgetReached
         # Some fields can only be configured when the model is instantiated.
         # This side-steps that by creating a new instance of the model with the
         # configurable fields set, then invoking that new instance.
@@ -206,6 +219,7 @@ def invoke(
             invoke_llm = self.llm

         if not (self.cache and cache_path_resolver):
+            self.llm_call_budget -= 1
             return invoke_llm.invoke(input, config, stop=stop, **kwargs)

         cache_path = cache_path_resolver.cache_path()
@@ -217,6 +231,7 @@ def invoke(
         if cache_entry:
             return cache_entry

+        self.llm_call_budget -= 1
         response = invoke_llm.invoke(input, config, stop=stop, **kwargs)

         self.cache.put(
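The ModelProvider side of this change is a simple countdown: a budget of -1 never reaches 0, so it means unlimited calls; any other value is decremented once per real LLM invocation, and once it hits 0 the next invoke raises LLMCallBudgetReached instead of calling the model. Below is a minimal, self-contained sketch of that pattern; BudgetedProvider and its fake invoke are illustrative stand-ins, not kai's actual ModelProvider API.

class LLMCallBudgetReached(Exception):
    def __init__(
        self, message: str = "The defined LLM call budget has been reached"
    ) -> None:
        super().__init__(message)


class BudgetedProvider:
    """Illustrative stand-in for ModelProvider's budget handling (not kai code)."""

    def __init__(self, llm_call_budget: int = -1) -> None:
        # -1 means unlimited: decrementing never lands exactly on 0.
        self.llm_call_budget = llm_call_budget

    def invoke(self, prompt: str) -> str:
        if self.llm_call_budget == 0:
            raise LLMCallBudgetReached
        self.llm_call_budget -= 1
        return f"response to {prompt!r}"


provider = BudgetedProvider(llm_call_budget=2)
provider.invoke("first")   # ok, budget drops to 1
provider.invoke("second")  # ok, budget drops to 0
try:
    provider.invoke("third")
except LLMCallBudgetReached as e:
    print(e)  # "The defined LLM call budget has been reached"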
diff --git a/kai/rpc_server/server.py b/kai/rpc_server/server.py
index a5308b01..5fe77170 100644
--- a/kai/rpc_server/server.py
+++ b/kai/rpc_server/server.py
@@ -6,16 +6,7 @@
 from itertools import groupby
 from operator import attrgetter
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Generator,
-    Optional,
-    ParamSpec,
-    TypedDict,
-    TypeVar,
-    cast,
-)
+from typing import Any, Optional, TypedDict, cast
 from unittest.mock import MagicMock
 from urllib.parse import urlparse
@@ -31,7 +22,7 @@
 from kai.jsonrpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcId
 from kai.jsonrpc.util import AutoAbsPath, AutoAbsPathExists, CamelCaseBaseModel
 from kai.kai_config import KaiConfigModels
-from kai.llm_interfacing.model_provider import ModelProvider
+from kai.llm_interfacing.model_provider import LLMCallBudgetReached, ModelProvider
 from kai.logging.logging import TRACE, KaiLogConfig, get_logger
 from kai.reactive_codeplanner.agent.analyzer_fix.agent import AnalyzerAgent
 from kai.reactive_codeplanner.agent.dependency_agent.dependency_agent import (
@@ -121,6 +112,7 @@ def __init__(self) -> None:
         self.analysis_validator: Optional[AnalyzerLSPStep] = None
         self.task_manager: Optional[TaskManager] = None
         self.rcm: Optional[RepoContextManager] = None
+        self.model_provider: Optional[ModelProvider] = None


 app = KaiRpcApplication()
@@ -201,6 +193,7 @@ def initialize(
            cache.model_id = re.sub(r"[\.:\\/]", "_", model_provider.model_id)

         model_provider.validate_environment()
+        app.model_provider = model_provider
     except Exception as e:
         server.send_response(
             id=id,
@@ -463,7 +456,9 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     file_path: Path
     incidents: list[ExtendedIncident]
+    # Deprecated in favor of llm_call_budget
     max_iterations: Optional[int] = None
+    llm_call_budget: Optional[int] = None
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None
@@ -478,6 +473,7 @@ def get_codeplan_agent_solution(
     id: JsonRpcId,
     params: GetCodeplanAgentSolutionParams,
 ) -> None:
+
     def simple_chat_message(msg: str) -> None:
         app.log.info("simple_chat_message!")
         server.send_notification(
@@ -491,6 +487,20 @@ def simple_chat_message(msg: str) -> None:
             },
         )

+    # This means unlimited calls, as counting down won't reach 0
+    llm_call_budget: int = -1
+    if params.max_iterations and not params.llm_call_budget:
+        llm_call_budget = params.max_iterations
+    elif params.llm_call_budget is not None:
+        llm_call_budget = params.llm_call_budget
+
+    if app.model_provider is not None:
+        # Set the LLM call budget for the new request
+        app.model_provider.llm_call_budget = llm_call_budget
+    else:
+        # We should never hit this branch, this is mostly to make mypy happy
+        raise Exception("ModelProvider not initialized")
+
     simple_chat_message("Starting!")

     try:
@@ -571,14 +581,10 @@
         app.task_manager.set_seed_tasks(*seed_tasks)

         app.log.info(
-            f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
+            f"Starting code plan loop with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
         simple_chat_message(
-            f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
-        )
-
-        next_task_fn = scoped_task_fn(
-            params.max_iterations, app.task_manager.get_next_task
+            f"Starting processing with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )

         class OverallResult(TypedDict):
@@ -599,7 +605,9 @@ class OverallResult(TypedDict):

         simple_chat_message("Running validators...")

-        for task in next_task_fn(params.max_priority, params.max_depth):
+        for task in app.task_manager.get_next_task(
+            params.max_priority, params.max_depth
+        ):
             app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
             if hasattr(task, "message"):
                 simple_chat_message(
@@ -615,7 +623,16 @@ class OverallResult(TypedDict):
             # get the ignored tasks set
             pre_task_ignored_tasks = set(app.task_manager.ignored_tasks)
-            result = app.task_manager.execute_task(task)
+            try:
+                result = app.task_manager.execute_task(task)
+            except LLMCallBudgetReached:
+                # We can no longer solve any problems, we'll need to put the task back in the queue and return now
+                app.task_manager.priority_queue.push(task)
+                app.log.info("LLM call budget reached, ending the loop early")
+                simple_chat_message(
+                    "LLM call budget has been reached. Stopping work..."
+                )
+                break
             app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
             # simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
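The loop above is the consumer side of the budget: when execute_task propagates LLMCallBudgetReached, the unfinished task is pushed back onto the priority queue so it is not lost, and the code plan loop ends early instead of failing the request. A rough sketch of that control flow follows; TaskQueue, execute_task, and the dict-based budget are toy stand-ins for illustration, not kai's TaskManager API.

from collections import deque


class LLMCallBudgetReached(Exception):
    pass


class TaskQueue:
    """Toy queue stand-in: pop from the front, push unfinished work back to the front."""

    def __init__(self, tasks: list[str]) -> None:
        self._tasks = deque(tasks)

    def pop(self) -> str:
        return self._tasks.popleft()

    def push(self, task: str) -> None:
        self._tasks.appendleft(task)

    def __bool__(self) -> bool:
        return bool(self._tasks)


def execute_task(task: str, budget: dict[str, int]) -> str:
    # Each task costs one LLM call in this toy model.
    if budget["calls"] == 0:
        raise LLMCallBudgetReached
    budget["calls"] -= 1
    return f"solved {task}"


queue = TaskQueue(["task-a", "task-b", "task-c"])
budget = {"calls": 2}  # only enough for two tasks
while queue:
    task = queue.pop()
    try:
        print(execute_task(task, budget))
    except LLMCallBudgetReached:
        queue.push(task)  # keep the unfinished task for a later request
        print("LLM call budget reached, ending the loop early")
        break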
@@ -706,30 +723,3 @@ class OverallResult(TypedDict):
         app.log.error(e)
         raise
-
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-def scoped_task_fn(
-    max_iterations: Optional[int], next_task_fn: Callable[P, Generator[R, Any, None]]
-) -> Callable[P, Generator[R, Any, None]]:
-    log = get_logger("fn_selection")
-    if max_iterations is None:
-        log.debug("No max_iterations, returning default get_next_task")
-        return next_task_fn
-
-    def inner(*args: P.args, **kwargs: P.kwargs) -> Generator[R, Any, None]:
-        log.info(f"In inner {args}, {kwargs}")
-        generator = next_task_fn(*args, **kwargs)
-        for i in range(max_iterations):
-            try:
-                log.debug(f"Yielding on iteration {i}")
-                yield next(generator)
-            except StopIteration:
-                break
-
-    log.debug("Returning the iteration-scoped get_next_task function")
-
-    return inner
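With scoped_task_fn removed, max_iterations no longer caps the task generator; it only survives as a deprecated fallback that seeds the LLM call budget. The precedence implemented in get_codeplan_agent_solution, pulled out into a standalone helper purely for illustration (resolve_llm_call_budget is a hypothetical name, not part of the server):

from typing import Optional


def resolve_llm_call_budget(
    max_iterations: Optional[int], llm_call_budget: Optional[int]
) -> int:
    # Hypothetical helper mirroring the request-handling logic above.
    budget = -1  # unlimited: counting down never reaches 0
    if max_iterations and not llm_call_budget:
        # Deprecated path: treat max_iterations as the budget when no
        # explicit llm_call_budget was supplied.
        budget = max_iterations
    elif llm_call_budget is not None:
        budget = llm_call_budget
    return budget


assert resolve_llm_call_budget(None, None) == -1  # nothing set: unlimited
assert resolve_llm_call_budget(10, None) == 10    # deprecated fallback
assert resolve_llm_call_budget(10, 3) == 3        # explicit budget wins
assert resolve_llm_call_budget(None, 7) == 7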