✨ Replace max_iterations with max_llm_calls #628
base: main
File: kai/llm_interfacing/model_provider.py
@@ -22,7 +22,17 @@
 LOG = get_logger(__name__)
 
 
+class LLMCallBudgetReached(Exception):
+    def __init__(
+        self, message: str = "The defined LLM call budget has been reached"
+    ) -> None:
+        super().__init__(message)
+
+
 class ModelProvider:
+
+    llm_call_budget: int = -1
+
     def __init__(
         self,
         config: KaiConfigModels,

Review thread on the new llm_call_budget attribute:

Comment: Is this approach going to cause issues with Jonah's async PR?

Reply: I don't think so. Each LLM call can still check whether the budget is reached before actually making the call. Asyncio is cooperatively multitasked.

Reply: I don't know. Wouldn't multiple calls to get_code_plan_solution cause the budget for ALL calls in the system to be reset? (AFAICT the ModelProvider is shared.)

Reply: Are we supporting multiple calls? The whole system kind of breaks down in that case anyway, doesn't it?

Reply: It is, but it might be nice for this not to be one more thing we have to remember if/when we do have multiple requests; that was my other thought. But maybe we have to re-architect the initialization of task managers/agents at that time.
@@ -33,6 +43,7 @@ def __init__(
         self.llm_retry_delay: float = config.llm_retry_delay
         self.demo_mode: bool = demo_mode
         self.cache = cache
+        self.llm_call_budget = config.llm_call_budget
 
         model_class: type[BaseChatModel]
         defaults: dict[str, Any]
@@ -193,6 +204,8 @@ def invoke(
         stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
+        if self.llm_call_budget == 0:
+            raise LLMCallBudgetReached
         # Some fields can only be configured when the model is instantiated.
         # This side-steps that by creating a new instance of the model with the
         # configurable fields set, then invoking that new instance.
@@ -206,6 +219,7 @@ def invoke(
         invoke_llm = self.llm
 
         if not (self.cache and cache_path_resolver):
+            self.llm_call_budget -= 1
             return invoke_llm.invoke(input, config, stop=stop, **kwargs)
 
         cache_path = cache_path_resolver.cache_path()
@@ -217,6 +231,7 @@
         if cache_entry:
             return cache_entry
 
+        self.llm_call_budget -= 1
         response = invoke_llm.invoke(input, config, stop=stop, **kwargs)
 
         self.cache.put(
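A minimal usage sketch of the behaviour added above; the exact ModelProvider constructor arguments are abbreviated and config stands for a populated KaiConfigModels instance, so treat the call signature as an assumption:

provider = ModelProvider(config, cache=None)  # config: a real KaiConfigModels
provider.llm_call_budget = 2  # allow at most two LLM calls; -1 means unlimited

try:
    first = provider.invoke("prompt 1")   # budget 2 -> 1
    second = provider.invoke("prompt 2")  # budget 1 -> 0
    third = provider.invoke("prompt 3")   # budget is 0: refused before any network call
except LLMCallBudgetReached as e:
    print(e)  # "The defined LLM call budget has been reached"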
File: the Kai RPC server module (defines KaiRpcApplication, initialize, and get_codeplan_agent_solution)
@@ -6,16 +6,7 @@
 from itertools import groupby
 from operator import attrgetter
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Generator,
-    Optional,
-    ParamSpec,
-    TypedDict,
-    TypeVar,
-    cast,
-)
+from typing import Any, Optional, TypedDict, cast
 from unittest.mock import MagicMock
 from urllib.parse import urlparse
@@ -31,7 +22,7 @@
 from kai.jsonrpc.models import JsonRpcError, JsonRpcErrorCode, JsonRpcId
 from kai.jsonrpc.util import AutoAbsPath, AutoAbsPathExists, CamelCaseBaseModel
 from kai.kai_config import KaiConfigModels
-from kai.llm_interfacing.model_provider import ModelProvider
+from kai.llm_interfacing.model_provider import LLMCallBudgetReached, ModelProvider
 from kai.logging.logging import TRACE, KaiLogConfig, get_logger
 from kai.reactive_codeplanner.agent.analyzer_fix.agent import AnalyzerAgent
 from kai.reactive_codeplanner.agent.dependency_agent.dependency_agent import (
@@ -121,6 +112,7 @@ def __init__(self) -> None:
         self.analysis_validator: Optional[AnalyzerLSPStep] = None
         self.task_manager: Optional[TaskManager] = None
         self.rcm: Optional[RepoContextManager] = None
+        self.model_provider: Optional[ModelProvider] = None
 
 
 app = KaiRpcApplication()
@@ -201,6 +193,7 @@ def initialize(
         cache.model_id = re.sub(r"[\.:\\/]", "_", model_provider.model_id)
 
         model_provider.validate_environment()
+        app.model_provider = model_provider
     except Exception as e:
         server.send_response(
             id=id,
@@ -463,7 +456,9 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     file_path: Path
     incidents: list[ExtendedIncident]
 
+    # Deprecated in favor of llm_call_budget
     max_iterations: Optional[int] = None
+    llm_call_budget: Optional[int] = None
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None

Review thread on the deprecated max_iterations field:

Comment: You can make this a bona fide deprecated field if you want: https://docs.pydantic.dev/latest/concepts/fields/#deprecated-fields

Reply: Or you could alias the other field to the new one.

Reply: +1, that's a good idea.
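A sketch of the two suggestions from the thread above, assuming Pydantic v2 (2.7 or newer for the deprecated argument); this is illustrative and not what the PR currently does:

from typing import Optional

from pydantic import AliasChoices, BaseModel, Field


class GetCodeplanAgentSolutionParams(BaseModel):
    # Option 1: a bona fide deprecated field that warns whenever it is accessed.
    max_iterations: Optional[int] = Field(
        default=None, deprecated="Use llm_call_budget instead"
    )
    # Option 2: accept the old name as an alias so older clients keep working.
    llm_call_budget: Optional[int] = Field(
        default=None,
        validation_alias=AliasChoices("llm_call_budget", "max_iterations"),
    )

If the alias route were taken, requests that still send max_iterations would populate llm_call_budget directly, which would remove the need for the fallback branch in get_codeplan_agent_solution below.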
@@ -478,6 +473,7 @@ def get_codeplan_agent_solution(
     id: JsonRpcId,
     params: GetCodeplanAgentSolutionParams,
 ) -> None:
+
     def simple_chat_message(msg: str) -> None:
         app.log.info("simple_chat_message!")
         server.send_notification(
@@ -491,6 +487,20 @@ def simple_chat_message(msg: str) -> None:
             },
         )
 
+    # This means unlimited calls, as counting down won't reach 0
+    llm_call_budget: int = -1
+    if params.max_iterations and not params.llm_call_budget:
+        llm_call_budget = params.max_iterations
+    elif params.llm_call_budget is not None:
+        llm_call_budget = params.llm_call_budget
+
+    if app.model_provider is not None:
+        # Set LLM call budget for the new request
+        app.model_provider.llm_call_budget = llm_call_budget
+    else:
+        # We should never hit this branch; this is mostly to make mypy happy
+        raise Exception("ModelProvider not initialized")
+
     simple_chat_message("Starting!")
 
     try:
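Worked examples of the selection logic above, pulled out into a hypothetical helper with the same behaviour:

def resolve_budget(max_iterations, llm_call_budget):
    # Mirrors the branch above: -1 means unlimited, and the deprecated field
    # is only honored when the new one is absent.
    budget = -1
    if max_iterations and not llm_call_budget:
        budget = max_iterations
    elif llm_call_budget is not None:
        budget = llm_call_budget
    return budget


assert resolve_budget(None, None) == -1  # nothing set: unlimited calls
assert resolve_budget(10, None) == 10    # old clients keep a working cap
assert resolve_budget(10, 25) == 25      # the new field wins when both are sent
assert resolve_budget(None, 0) == 0      # an explicit 0 refuses every LLM call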
@@ -571,14 +581,10 @@ def simple_chat_message(msg: str) -> None:
         app.task_manager.set_seed_tasks(*seed_tasks)
 
         app.log.info(
-            f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
+            f"Starting code plan loop with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
         simple_chat_message(
-            f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
-        )
-
-        next_task_fn = scoped_task_fn(
-            params.max_iterations, app.task_manager.get_next_task
+            f"Starting processing with llm call budget: {llm_call_budget}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
 
         class OverallResult(TypedDict):
@@ -599,7 +605,9 @@ class OverallResult(TypedDict):
 
         simple_chat_message("Running validators...")
 
-        for task in next_task_fn(params.max_priority, params.max_depth):
+        for task in app.task_manager.get_next_task(
+            params.max_priority, params.max_depth
+        ):
             app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
             if hasattr(task, "message"):
                 simple_chat_message(
@@ -615,7 +623,16 @@ class OverallResult(TypedDict):
             # get the ignored tasks set
             pre_task_ignored_tasks = set(app.task_manager.ignored_tasks)
 
-            result = app.task_manager.execute_task(task)
+            try:
+                result = app.task_manager.execute_task(task)
+            except LLMCallBudgetReached:
+                # We can no longer solve any problems, we'll need to put the task back in the queue and return now
+                app.task_manager.priority_queue.push(task)
+                app.log.info("LLM call budget reached, ending the loop early")
+                simple_chat_message(
+                    "LLM call budget has been reached. Stopping work..."
+                )
+                break
 
             app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
             # simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
@@ -706,30 +723,3 @@ class OverallResult(TypedDict):
 
         app.log.error(e)
         raise
-
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-def scoped_task_fn(
-    max_iterations: Optional[int], next_task_fn: Callable[P, Generator[R, Any, None]]
-) -> Callable[P, Generator[R, Any, None]]:
-    log = get_logger("fn_selection")
-    if max_iterations is None:
-        log.debug("No max_iterations, returning default get_next_task")
-        return next_task_fn
-
-    def inner(*args: P.args, **kwargs: P.kwargs) -> Generator[R, Any, None]:
-        log.info(f"In inner {args}, {kwargs}")
-        generator = next_task_fn(*args, **kwargs)
-        for i in range(max_iterations):
-            try:
-                log.debug(f"Yielding on iteration {i}")
-                yield next(generator)
-            except StopIteration:
-                break
-
-    log.debug("Returning the iteration-scoped get_next_task function")
-
-    return inner
Review comment: Should we say the budget that was set?
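One way to do what the comment asks, sketched as a variant of the exception (hypothetical, not part of the PR); the provider would need to remember the configured value separately, since llm_call_budget itself has already counted down to 0 by the time the exception is raised:

from typing import Optional


class LLMCallBudgetReached(Exception):
    def __init__(self, budget: Optional[int] = None) -> None:
        detail = f" (budget was set to {budget})" if budget is not None else ""
        super().__init__(f"The defined LLM call budget has been reached{detail}")


# At the raise site, something like:
#     raise LLMCallBudgetReached(self.configured_llm_call_budget)
# where configured_llm_call_budget is a hypothetical attribute holding the
# originally requested budget.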