Replace max_iterations with max_llm_calls
- Add a counter to the model provider to optionally limit the number of
  LLM requests that can be made (a usage sketch follows below).
- Remove all references to max_iterations (except in the rpc_server
  params; we can remove that in a few releases).

Signed-off-by: Fabian von Feilitzsch <[email protected]>
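The counter added here is a simple countdown in the provider's invoke path: each LLM call decrements llm_call_budget, a value of 0 means the budget is spent and LLMCallBudgetReached is raised, and the default of -1 can never count down to 0, so calls stay unlimited. Below is a minimal standalone sketch of that pattern; it is a toy stand-in, not the project's actual ModelProvider (which needs a full KaiConfigModels and a real LLM backend).

class LLMCallBudgetReached(Exception):
    """Raised when the configured number of LLM calls has been used up."""


class BudgetedClient:
    """Toy stand-in for ModelProvider; only the budget bookkeeping is modeled."""

    def __init__(self, llm_call_budget: int = -1) -> None:
        # -1 means unlimited: counting down from -1 never reaches 0.
        self.llm_call_budget = llm_call_budget

    def invoke(self, prompt: str) -> str:
        if self.llm_call_budget == 0:
            raise LLMCallBudgetReached()
        self.llm_call_budget -= 1
        return f"fake response to {prompt!r}"  # a real provider would call the LLM here


client = BudgetedClient(llm_call_budget=2)
client.invoke("first")   # ok, budget drops to 1
client.invoke("second")  # ok, budget drops to 0
try:
    client.invoke("third")
except LLMCallBudgetReached:
    print("budget exhausted, stopping early")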
fabianvf committed Feb 6, 2025
1 parent 7cba2a0 commit 3e25470
Showing 3 changed files with 52 additions and 46 deletions.
1 change: 1 addition & 0 deletions kai/kai_config.py
@@ -166,6 +166,7 @@ class KaiConfigModels(BaseModel):
     llama_header: Optional[bool] = Field(default=None)
     llm_retries: int = 5
     llm_retry_delay: float = 10.0
+    llm_call_budget: int = -1
 
 
 # Main config
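If the models section of the Kai config maps directly onto KaiConfigModels, capping a request at ten LLM calls is one extra field next to the retry settings. The following is a hypothetical sketch: only llm_retries, llm_retry_delay, and llm_call_budget come from the class shown above; the provider and args values are assumed placeholders, not part of this diff.

from kai.kai_config import KaiConfigModels

# provider/args are illustrative placeholders; the three llm_* fields mirror
# the attributes shown in the diff above.
models_config = KaiConfigModels(
    provider="ChatOpenAI",
    args={"model": "gpt-4o"},
    llm_retries=5,
    llm_retry_delay=10.0,
    llm_call_budget=10,  # -1 (the default) would leave calls unlimited
)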
15 changes: 15 additions & 0 deletions kai/llm_interfacing/model_provider.py
@@ -22,7 +22,17 @@
 LOG = get_logger(__name__)
 
 
+class LLMCallBudgetReached(Exception):
+    def __init__(
+        self, message: str = "The defined LLM call budget has been reached"
+    ) -> None:
+        super().__init__(message)
+
+
 class ModelProvider:
+
+    llm_call_budget: int = -1
+
     def __init__(
         self,
         config: KaiConfigModels,
@@ -33,6 +43,7 @@ def __init__(
         self.llm_retry_delay: float = config.llm_retry_delay
         self.demo_mode: bool = demo_mode
         self.cache = cache
+        self.llm_call_budget = config.llm_call_budget
 
         model_class: type[BaseChatModel]
         defaults: dict[str, Any]
@@ -193,6 +204,8 @@ def invoke(
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
+        if self.llm_call_budget == 0:
+            raise LLMCallBudgetReached
         # Some fields can only be configured when the model is instantiated.
         # This side-steps that by creating a new instance of the model with the
         # configurable fields set, then invoking that new instance.
@@ -206,6 +219,7 @@ def invoke(
             invoke_llm = self.llm
 
         if not (self.cache and cache_path_resolver):
+            self.llm_call_budget -= 1
             return invoke_llm.invoke(input, config, stop=stop, **kwargs)
 
         cache_path = cache_path_resolver.cache_path()
@@ -217,6 +231,7 @@
         if cache_entry:
            return cache_entry
 
+        self.llm_call_budget -= 1
         response = invoke_llm.invoke(input, config, stop=stop, **kwargs)
 
         self.cache.put(
82 changes: 36 additions & 46 deletions kai/rpc_server/server.py
@@ -6,16 +6,7 @@
 from itertools import groupby
 from operator import attrgetter
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Generator,
-    Optional,
-    ParamSpec,
-    TypedDict,
-    TypeVar,
-    cast,
-)
+from typing import Any, Optional, TypedDict, cast
 from unittest.mock import MagicMock
 from urllib.parse import urlparse
 
@@ -31,7 +22,7 @@
 from kai.jsonrpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcId
 from kai.jsonrpc.util import AutoAbsPath, AutoAbsPathExists, CamelCaseBaseModel
 from kai.kai_config import KaiConfigModels
-from kai.llm_interfacing.model_provider import ModelProvider
+from kai.llm_interfacing.model_provider import LLMCallBudgetReached, ModelProvider
 from kai.logging.logging import TRACE, KaiLogConfig, get_logger
 from kai.reactive_codeplanner.agent.analyzer_fix.agent import AnalyzerAgent
 from kai.reactive_codeplanner.agent.dependency_agent.dependency_agent import (
@@ -121,6 +112,7 @@ def __init__(self) -> None:
         self.analysis_validator: Optional[AnalyzerLSPStep] = None
         self.task_manager: Optional[TaskManager] = None
         self.rcm: Optional[RepoContextManager] = None
+        self.model_provider: Optional[ModelProvider] = None
 
 
 app = KaiRpcApplication()
@@ -201,6 +193,7 @@ def initialize(
         cache.model_id = re.sub(r"[\.:\\/]", "_", model_provider.model_id)
 
         model_provider.validate_environment()
+        app.model_provider = model_provider
     except Exception as e:
         server.send_response(
             id=id,
@@ -463,7 +456,9 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     file_path: Path
     incidents: list[ExtendedIncident]
 
+    # Deprecated in favor of llm_call_budget
     max_iterations: Optional[int] = None
+    llm_call_budget: Optional[int] = None
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None
 
@@ -478,6 +473,7 @@ def get_codeplan_agent_solution(
     id: JsonRpcId,
     params: GetCodeplanAgentSolutionParams,
 ) -> None:
+
     def simple_chat_message(msg: str) -> None:
         app.log.info("simple_chat_message!")
         server.send_notification(
@@ -491,6 +487,20 @@ def simple_chat_message(msg: str) -> None:
             },
         )
 
+    # This means unlimited calls, as counting down won't reach 0
+    llm_call_budget: int = -1
+    if params.max_iterations and not params.llm_call_budget:
+        llm_call_budget = params.max_iterations
+    elif params.llm_call_budget is not None:
+        llm_call_budget = params.llm_call_budget
+
+    if app.model_provider is not None:
+        # Set the LLM call budget for the new request
+        app.model_provider.llm_call_budget = llm_call_budget
+    else:
+        # We should never hit this branch; this is mostly to make mypy happy
+        raise Exception("ModelProvider not initialized")
+
     simple_chat_message("Starting!")
 
     try:
@@ -571,14 +581,10 @@ def simple_chat_message(msg: str) -> None:
         app.task_manager.set_seed_tasks(*seed_tasks)
 
         app.log.info(
-            f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
+            f"Starting code plan loop with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
         simple_chat_message(
-            f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
-        )
-
-        next_task_fn = scoped_task_fn(
-            params.max_iterations, app.task_manager.get_next_task
+            f"Starting processing with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
 
         class OverallResult(TypedDict):
@@ ... @@
 
         simple_chat_message("Running validators...")
 
-        for task in next_task_fn(params.max_priority, params.max_depth):
+        for task in app.task_manager.get_next_task(
+            params.max_priority, params.max_depth
+        ):
             app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
             if hasattr(task, "message"):
                 simple_chat_message(
@@ ... @@
             # get the ignored tasks set
             pre_task_ignored_tasks = set(app.task_manager.ignored_tasks)
 
-            result = app.task_manager.execute_task(task)
+            try:
+                result = app.task_manager.execute_task(task)
+            except LLMCallBudgetReached:
+                # We can no longer solve any problems, we'll need to put the task back in the queue and return now
+                app.task_manager.priority_queue.push(task)
+                app.log.info("LLM call budget reached, ending the loop early")
+                simple_chat_message(
+                    "LLM call budget has been reached. Stopping work..."
+                )
+                break
 
             app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
             # simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
@@ -706,30 +723,3 @@ class OverallResult(TypedDict):
 
         app.log.error(e)
         raise
-
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-def scoped_task_fn(
-    max_iterations: Optional[int], next_task_fn: Callable[P, Generator[R, Any, None]]
-) -> Callable[P, Generator[R, Any, None]]:
-    log = get_logger("fn_selection")
-    if max_iterations is None:
-        log.debug("No max_iterations, returning default get_next_task")
-        return next_task_fn
-
-    def inner(*args: P.args, **kwargs: P.kwargs) -> Generator[R, Any, None]:
-        log.info(f"In inner {args}, {kwargs}")
-        generator = next_task_fn(*args, **kwargs)
-        for i in range(max_iterations):
-            try:
-                log.debug(f"Yielding on iteration {i}")
-                yield next(generator)
-            except StopIteration:
-                break
-
-    log.debug("Returning the iteration-scoped get_next_task function")
-
-    return inner
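For reference, the precedence between the deprecated max_iterations parameter and the new llm_call_budget can be restated as a small standalone helper. This is a hypothetical sketch, not code from the commit; it simply mirrors the budget-selection branch added to get_codeplan_agent_solution above.

from typing import Optional


def resolve_llm_call_budget(
    max_iterations: Optional[int], llm_call_budget: Optional[int]
) -> int:
    # Mirrors the branch above: a truthy max_iterations is honoured only when
    # llm_call_budget is unset or falsy; an explicit llm_call_budget wins
    # otherwise; with neither set, -1 keeps calls unlimited.
    budget = -1
    if max_iterations and not llm_call_budget:
        budget = max_iterations
    elif llm_call_budget is not None:
        budget = llm_call_budget
    return budget


assert resolve_llm_call_budget(None, None) == -1  # nothing set: unlimited
assert resolve_llm_call_budget(5, None) == 5      # old clients keep working
assert resolve_llm_call_budget(5, 12) == 12       # new field takes precedence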
