✨ Replace max_iterations with max_llm_calls #628

Open
wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions kai/kai_config.py
@@ -166,6 +166,7 @@ class KaiConfigModels(BaseModel):
llama_header: Optional[bool] = Field(default=None)
llm_retries: int = 5
llm_retry_delay: float = 10.0
llm_call_budget: int = -1


# Main config
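
For context on the new field's default: -1 is treated as unlimited, because the provider only refuses a call when the counter is exactly 0 and decrementing a negative number never gets there. A minimal standalone sketch of that sentinel behavior (hypothetical helper, not part of the diff):

def consume_budget(remaining: int) -> int:
    """Return the budget left after one LLM call, refusing once it hits zero."""
    if remaining == 0:
        raise RuntimeError("LLM call budget reached")  # stand-in for LLMCallBudgetReached
    return remaining - 1  # -1 keeps decreasing (-2, -3, ...) and never reaches 0

budget = -1  # the new default: effectively unlimited
for _ in range(3):
    budget = consume_budget(budget)  # never raises with the -1 sentinel
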
15 changes: 15 additions & 0 deletions kai/llm_interfacing/model_provider.py
@@ -22,7 +22,17 @@
LOG = get_logger(__name__)


class LLMCallBudgetReached(Exception):
def __init__(
self, message: str = "The defined LLM call budget has been reached"
Contributor: Should we say the budget that was set?

) -> None:
super().__init__(message)
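
On the reviewer's question above about reporting the budget that was set: one illustrative variant (an assumption, not what this PR does) would pass the configured budget into the exception:

from typing import Optional

class LLMCallBudgetReached(Exception):
    def __init__(self, budget: Optional[int] = None) -> None:
        message = "The defined LLM call budget has been reached"
        if budget is not None:
            message += f" (budget was {budget})"  # hypothetical extra detail
        super().__init__(message)

A call site would then need to keep the originally configured value around to pass in, since the live counter is already 0 by the time the exception fires.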


class ModelProvider:

llm_call_budget: int = -1
Contributor: Is this approach going to cause issues with Jonah's async PR?

Contributor: I don't think so. Each LLM call can still check if the budget is reached before actually making the call. Asyncio is cooperatively multitasked.

Contributor: I don't know. Wouldn't multiple calls to get_code_plan_solution cause the budget for ALL calls in the system to be reset? (AFAICT ModelProvider is shared.)

Contributor Author: Are we supporting multiple calls? The whole system kind of breaks down in that case anyway, doesn't it?

Contributor: It is, but it might be nice to have this not be one more thing that we have to remember if/when we do have multiple requests, was my other thought.

But maybe we have to re-architect the initialization of task managers/agents at that time.
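
Purely as an illustration of the per-request concern raised in this thread (hypothetical names, not part of the PR), the budget could live in a small object created per request and threaded through the call chain instead of in shared ModelProvider state:

from dataclasses import dataclass

@dataclass
class CallBudget:
    remaining: int = -1  # -1 means unlimited

    def consume(self) -> None:
        """Spend one LLM call, raising once this request's budget is exhausted."""
        if self.remaining == 0:
            raise RuntimeError("LLM call budget reached")  # or LLMCallBudgetReached
        if self.remaining > 0:
            self.remaining -= 1

Each incoming request would construct its own CallBudget, so concurrent requests could not reset or drain each other's budget through the shared provider.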


def __init__(
self,
config: KaiConfigModels,
@@ -33,6 +43,7 @@ def __init__(
self.llm_retry_delay: float = config.llm_retry_delay
self.demo_mode: bool = demo_mode
self.cache = cache
self.llm_call_budget = config.llm_call_budget

model_class: type[BaseChatModel]
defaults: dict[str, Any]
@@ -193,6 +204,8 @@ def invoke(
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if self.llm_call_budget == 0:
raise LLMCallBudgetReached
# Some fields can only be configured when the model is instantiated.
# This side-steps that by creating a new instance of the model with the
# configurable fields set, then invoking that new instance.
@@ -206,6 +219,7 @@ def invoke(
invoke_llm = self.llm

if not (self.cache and cache_path_resolver):
self.llm_call_budget -= 1
return invoke_llm.invoke(input, config, stop=stop, **kwargs)

cache_path = cache_path_resolver.cache_path()
@@ -217,6 +231,7 @@ def invoke(
if cache_entry:
return cache_entry

self.llm_call_budget -= 1
response = invoke_llm.invoke(input, config, stop=stop, **kwargs)

self.cache.put(
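
To make the accounting in invoke() concrete: the budget check happens before the cache lookup, so a request with an exhausted budget fails fast even if the answer is cached; a cache hit otherwise consumes no budget, and only a real LLM call decrements the counter. A loose standalone sketch of those rules (hypothetical names, not the actual ModelProvider API):

from typing import Callable, Optional

def invoke_with_budget(
    call_llm: Callable[[str], str],
    prompt: str,
    remaining_budget: int,
    cache: Optional[dict[str, str]] = None,
) -> tuple[str, int]:
    """Return (response, remaining_budget) following the same budget rules."""
    if remaining_budget == 0:
        raise RuntimeError("LLM call budget reached")  # stand-in for LLMCallBudgetReached
    if cache is not None and prompt in cache:
        return cache[prompt], remaining_budget  # cache hit: no budget consumed
    response = call_llm(prompt)
    if cache is not None:
        cache[prompt] = response
    return response, remaining_budget - 1  # real call: one unit consumed
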
82 changes: 36 additions & 46 deletions kai/rpc_server/server.py
@@ -6,16 +6,7 @@
from itertools import groupby
from operator import attrgetter
from pathlib import Path
from typing import (
Any,
Callable,
Generator,
Optional,
ParamSpec,
TypedDict,
TypeVar,
cast,
)
from typing import Any, Optional, TypedDict, cast
from unittest.mock import MagicMock
from urllib.parse import urlparse

@@ -31,7 +22,7 @@
from kai.jsonrpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcId
from kai.jsonrpc.util import AutoAbsPath, AutoAbsPathExists, CamelCaseBaseModel
from kai.kai_config import KaiConfigModels
from kai.llm_interfacing.model_provider import ModelProvider
from kai.llm_interfacing.model_provider import LLMCallBudgetReached, ModelProvider
from kai.logging.logging import TRACE, KaiLogConfig, get_logger
from kai.reactive_codeplanner.agent.analyzer_fix.agent import AnalyzerAgent
from kai.reactive_codeplanner.agent.dependency_agent.dependency_agent import (
@@ -121,6 +112,7 @@ def __init__(self) -> None:
self.analysis_validator: Optional[AnalyzerLSPStep] = None
self.task_manager: Optional[TaskManager] = None
self.rcm: Optional[RepoContextManager] = None
self.model_provider: Optional[ModelProvider] = None


app = KaiRpcApplication()
@@ -201,6 +193,7 @@ def initialize(
cache.model_id = re.sub(r"[\.:\\/]", "_", model_provider.model_id)

model_provider.validate_environment()
app.model_provider = model_provider
except Exception as e:
server.send_response(
id=id,
@@ -463,7 +456,9 @@ class GetCodeplanAgentSolutionParams(BaseModel):
file_path: Path
incidents: list[ExtendedIncident]

# Deprecated in favor of llm_call_budget
Contributor: You can make this a bona fide deprecated field if you want: https://docs.pydantic.dev/latest/concepts/fields/#deprecated-fields

Contributor: Or you could alias the other field to the new one

Contributor Author: +1 that's a good idea

max_iterations: Optional[int] = None
llm_call_budget: Optional[int] = None
max_depth: Optional[int] = None
max_priority: Optional[int] = None
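
A rough sketch of the two suggestions above, assuming Pydantic v2.7+ for the deprecated= argument (field set trimmed to the relevant pair, and in practice you would pick one of the two options; this is not the PR's actual code):

from typing import Optional
from pydantic import AliasChoices, BaseModel, Field

class GetCodeplanAgentSolutionParams(BaseModel):
    # Option 1: mark the legacy field as deprecated so callers get a warning.
    max_iterations: Optional[int] = Field(
        default=None, deprecated="Deprecated in favor of llm_call_budget"
    )
    # Option 2: let the new field also accept the legacy name on input.
    llm_call_budget: Optional[int] = Field(
        default=None,
        validation_alias=AliasChoices("llm_call_budget", "max_iterations"),
    )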

@@ -478,6 +473,7 @@ def get_codeplan_agent_solution(
id: JsonRpcId,
params: GetCodeplanAgentSolutionParams,
) -> None:

def simple_chat_message(msg: str) -> None:
app.log.info("simple_chat_message!")
server.send_notification(
@@ -491,6 +487,7 @@ def simple_chat_message(msg: str) -> None:
},
)

# This means unlimited calls, as counting down won't reach 0
llm_call_budget: int = -1
if params.max_iterations and not params.llm_call_budget:
llm_call_budget = params.max_iterations
elif params.llm_call_budget is not None:
llm_call_budget = params.llm_call_budget

if app.model_provider is not None:
# Set LLM call budget for the new request
app.model_provider.llm_call_budget = llm_call_budget
else:
# We should never hit this branch; it exists mostly to make mypy happy
raise Exception("ModelProvider not initialized")

simple_chat_message("Starting!")

try:
@@ -571,14 +581,10 @@ def simple_chat_message(msg: str) -> None:
app.task_manager.set_seed_tasks(*seed_tasks)

app.log.info(
f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
f"Starting code plan loop with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
)
simple_chat_message(
f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
)

next_task_fn = scoped_task_fn(
params.max_iterations, app.task_manager.get_next_task
f"Starting processing with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
)

class OverallResult(TypedDict):
@@ -599,7 +605,9 @@

simple_chat_message("Running validators...")

for task in next_task_fn(params.max_priority, params.max_depth):
for task in app.task_manager.get_next_task(
params.max_priority, params.max_depth
):
app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
if hasattr(task, "message"):
simple_chat_message(
Expand All @@ -615,7 +623,16 @@ class OverallResult(TypedDict):
# get the ignored tasks set
pre_task_ignored_tasks = set(app.task_manager.ignored_tasks)

result = app.task_manager.execute_task(task)
try:
result = app.task_manager.execute_task(task)
except LLMCallBudgetReached:
# We can no longer solve any problems; put the task back in the queue and end the loop now
app.task_manager.priority_queue.push(task)
app.log.info("LLM call budget reached, ending the loop early")
simple_chat_message(
"LLM call budget has been reached. Stopping work..."
)
break

app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
# simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
@@ -706,30 +723,3 @@ class OverallResult(TypedDict):

app.log.error(e)
raise


P = ParamSpec("P")
R = TypeVar("R")


def scoped_task_fn(
max_iterations: Optional[int], next_task_fn: Callable[P, Generator[R, Any, None]]
) -> Callable[P, Generator[R, Any, None]]:
log = get_logger("fn_selection")
if max_iterations is None:
log.debug("No max_iterations, returning default get_next_task")
return next_task_fn

def inner(*args: P.args, **kwargs: P.kwargs) -> Generator[R, Any, None]:
log.info(f"In inner {args}, {kwargs}")
generator = next_task_fn(*args, **kwargs)
for i in range(max_iterations):
try:
log.debug(f"Yielding on iteration {i}")
yield next(generator)
except StopIteration:
break

log.debug("Returning the iteration-scoped get_next_task function")

return inner