
✨ Make Kai RPC Server fail initialize call when model provider cannot be queried (#614)

* Added model environment validation

Signed-off-by: JonahSussman <[email protected]>

* Further work

Signed-off-by: JonahSussman <[email protected]>

* Made trunk happy

Signed-off-by: JonahSussman <[email protected]>

* Pinned all opentelemetry packages

Signed-off-by: JonahSussman <[email protected]>

* Fixed tests and e2e demo

Signed-off-by: JonahSussman <[email protected]>

* Fixed review comments

Signed-off-by: JonahSussman <[email protected]>

---------

Signed-off-by: JonahSussman <[email protected]>
JonahSussman authored Feb 5, 2025
1 parent cf5ffb2 commit 7cba2a0
Showing 14 changed files with 246 additions and 101 deletions.
44 changes: 41 additions & 3 deletions .trunk/configs/custom-words.txt
@@ -1,16 +1,29 @@
analyzerfix
ASGI
caikit
celerybeat
codeplan
codeplanner
Connor
coolstore
coolstuff
customise
Cython
deepseek
diffable
djzager
dmypy
dotenv
fabianvf
flywaydb
frobinate
genai
gpgsign
htmlcov
imgui
initialcontextfactory
ipynb
ipython
jakartaee
javaee
javax
@@ -23,24 +36,49 @@ konveyor
langchain
levelname
LOGLEVEL
logr
microprofile
mistralai
mixtral
mkdocs
moderations
mypy
mypyc
nbdev
nosetests
OBJC
Ollama
pgaikwad
picketlink
pipenv
Pipfile
prio
pybuilder
pycache
pydantic
pyenv
pyflow
pyinstaller
pylspclient
pypa
pypackages
pytest
PYTHONPATH
pytype
quarkus
resteasy
ropeproject
Scrapy
sdist
smallrye
springboot
Spyder
spyderproject
spyproject
sussman
templ
tgis
tiiuae
venv
webassets
webmvc
prio
analyzerfix
fabianvf
28 changes: 26 additions & 2 deletions kai/cache.py
@@ -22,10 +22,22 @@ class CachePathResolver(ABC):
"""

@abstractmethod
def cache_path(self) -> Path: ...
def cache_path(self) -> Path:
"""
Generates a path to store cache
NOTE: This method should only be called once per desired path! You
should store the result in a variable if you need to use it multiple
times.
"""
...

@abstractmethod
def cache_meta(self) -> dict[str, str]: ...
def cache_meta(self) -> dict[str, str]:
"""
Generates metadata to store with cache
"""
...


class Cache(ABC):
@@ -248,3 +260,15 @@ def cache_path(self) -> Path:
path = self._dfs(self.task) / f"{self._req_count}_{self.request_type}.json"
self._req_count += 1
return path


class SimplePathResolver(CachePathResolver):
def __init__(self, path: Path | str, meta: Optional[dict[str, str]] = None) -> None:
self.path = Path(path)
self.meta = meta

def cache_path(self) -> Path:
return self.path

def cache_meta(self) -> dict[str, str]:
return self.meta or {}
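
For orientation, here is a minimal usage sketch of the new resolver (the meta dict below is a hypothetical value, not taken from this diff): unlike the task-based resolvers above, it returns the same path on every call, so it is safe to call cache_path() repeatedly.

from kai.cache import SimplePathResolver

resolver = SimplePathResolver("validate_environment.json", meta={"purpose": "env-check"})
path = resolver.cache_path()   # always Path("validate_environment.json")
meta = resolver.cache_meta()   # {"purpose": "env-check"}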
26 changes: 26 additions & 0 deletions kai/data/llm_cache/kai-test-generation/validate_environment.json
@@ -0,0 +1,26 @@
{
"input": "",
"output": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"messages",
"AIMessage"
],
"kwargs": {
"content": "Hello",
"response_metadata": {
"finish_reason": "length",
"model_name": "kai-test-generation",
"system_fingerprint": "fp_50cad350e4"
},
"type": "ai",
"id": "run-b3817ca3-910f-4b5b-9a89-cd3a579c234a-0",
"tool_calls": [],
"invalid_tool_calls": []
}
},
"meta": {}
}
12 changes: 11 additions & 1 deletion kai/kai_config.py
@@ -149,8 +149,18 @@ class KaiConfigIncidentStore(BaseModel):
# Model providers


class SupportedModelProviders(StrEnum):
CHAT_OLLAMA = "ChatOllama"
CHAT_OPENAI = "ChatOpenAI"
CHAT_BEDROCK = "ChatBedrock"
FAKE_LIST_CHAT_MODEL = "FakeListChatModel"
CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
AZURE_CHAT_OPENAI = "AzureChatOpenAI"
CHAT_DEEP_SEEK = "ChatDeepSeek"


class KaiConfigModels(BaseModel):
provider: str
provider: SupportedModelProviders
args: dict[str, Any] = Field(default_factory=dict)
template: Optional[str] = Field(default=None)
llama_header: Optional[bool] = Field(default=None)
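
A rough sketch of what the stricter typing buys (the model name in args is a hypothetical example): an unknown provider string now fails validation when the config is parsed, instead of surfacing later as an unmatched provider case.

from kai.kai_config import KaiConfigModels, SupportedModelProviders

models = KaiConfigModels(
    provider=SupportedModelProviders.CHAT_OPENAI,  # a plain "ChatOpenAI" string should also coerce
    args={"model": "gpt-4o-mini"},                 # hypothetical model name
)
KaiConfigModels(provider="NotARealProvider")       # raises pydantic.ValidationError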
125 changes: 55 additions & 70 deletions kai/llm_interfacing/model_provider.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import datetime
import json
import os
from typing import Any, Optional

@@ -10,14 +8,14 @@
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables import ConfigurableField, RunnableConfig
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1.utils import deep_update

from kai.cache import Cache, CachePathResolver
from kai.cache import Cache, CachePathResolver, SimplePathResolver
from kai.kai_config import KaiConfigModels
from kai.logging.logging import get_logger

@@ -40,6 +38,7 @@ def __init__(
defaults: dict[str, Any]
model_args: dict[str, Any]
model_id: str

# Set the model class, model args, and model id based on the provider
match config.provider:
case "ChatOllama":
@@ -61,9 +60,6 @@ def __init__(
defaults = {
"model": "gpt-3.5-turbo",
"temperature": 0.1,
# "model_kwargs": {
# "max_tokens": None,
# },
"streaming": True,
}

@@ -159,86 +155,75 @@ def __init__(
else:
self.template = config.template

if config.llama_header is None:
self.llama_header = self.model_id in [
"mistralai/mistral-7b-instruct-v0-2",
"mistralai/mixtral-8x7b-instruct-v01",
"codellama/codellama-34b-instruct",
"codellama/codellama-70b-instruct",
"deepseek-ai/deepseek-coder-33b-instruct",
"tiiuae/falcon-180b",
"tiiuae/falcon-40b",
"ibm/falcon-40b-8lang-instruct",
"meta-llama/llama-2-70b-chat",
"meta-llama/llama-2-13b-chat",
"meta-llama/llama-2-7b",
"meta-llama/llama-3-70b-instruct",
"meta-llama/llama-3-8b-instruct",
]
else:
self.llama_header = config.llama_header
def validate_environment(
self,
) -> None:
"""
Raises an exception if the environment is not set up correctly for the
current model provider.
"""

cpr = SimplePathResolver("validate_environment.json")

def challenge(k: str) -> BaseMessage:
return self.invoke("", cpr, configurable_fields={k: 1})

if isinstance(self.llm, ChatOllama):
challenge("max_tokens")
elif isinstance(self.llm, ChatOpenAI):
challenge("max_tokens")
elif isinstance(self.llm, ChatBedrock):
challenge("max_tokens")
elif isinstance(self.llm, FakeListChatModel):
pass
elif isinstance(self.llm, ChatGoogleGenerativeAI):
challenge("max_output_tokens")
elif isinstance(self.llm, AzureChatOpenAI):
challenge("max_tokens")
elif isinstance(self.llm, ChatDeepSeek):
challenge("max_tokens")

def invoke(
self,
input: LanguageModelInput,
cache_path_resolver: CachePathResolver,
cache_path_resolver: Optional[CachePathResolver] = None,
config: Optional[RunnableConfig] = None,
*,
configurable_fields: Optional[dict[str, Any]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> BaseMessage:
# Some fields can only be configured when the model is instantiated.
# This side-steps that by creating a new instance of the model with the
# configurable fields set, then invoking that new instance.
if configurable_fields is not None:
invoke_llm = self.llm.configurable_fields(
**{k: ConfigurableField(id=k) for k in configurable_fields}
).with_config(
configurable_fields # type: ignore[arg-type]
)
else:
invoke_llm = self.llm

if not (self.cache and cache_path_resolver):
return invoke_llm.invoke(input, config, stop=stop, **kwargs)

cache_path = cache_path_resolver.cache_path()
cache_meta = cache_path_resolver.cache_meta()

if self.demo_mode and self.cache:
if self.demo_mode:
cache_entry = self.cache.get(path=cache_path, input=input)

if cache_entry:
return cache_entry

response = self.llm.invoke(input, config, stop=stop, **kwargs)
response = invoke_llm.invoke(input, config, stop=stop, **kwargs)

if self.cache:
self.cache.put(
path=cache_path,
input=input,
output=response,
cache_meta=cache_meta,
)
self.cache.put(
path=cache_path,
input=input,
output=response,
cache_meta=cache_meta,
)

return response


# TODO(Shawn): Remove when we get to config update that
def str_to_bool(val: str) -> bool:
"""
Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError("invalid truth value %r" % (val,))


def get_env_bool(key: str, default: Optional[bool] = None) -> bool | None:
"""
Get a boolean value from an environment variable, returning the default if
the variable is not set.
"""
val = os.getenv(key)
if val is None:
return default
return str_to_bool(val)


class DatetimeEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime.datetime):
return obj.isoformat()
return super().default(obj)
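
To show how the two new pieces fit together, here is a hedged sketch of an initialize-time check. The ModelProvider name is inferred from the module, model_provider is assumed to be an already-constructed provider instance, and the surrounding error handling is an illustration rather than the server's actual code.

from kai.cache import SimplePathResolver

try:
    # Sends a one-token "challenge" request so a bad key or endpoint fails now,
    # during initialize, rather than on the first real generation.
    model_provider.validate_environment()
except Exception as exc:
    raise RuntimeError(f"model provider could not be queried: {exc}") from exc

# invoke() can also override instantiation-only fields per call, e.g. capping
# the response at a single token, which is what validate_environment does:
message = model_provider.invoke(
    "",
    SimplePathResolver("validate_environment.json"),
    configurable_fields={"max_tokens": 1},
)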
16 changes: 11 additions & 5 deletions kai/reactive_codeplanner/agent/dependency_agent/util.py
@@ -86,9 +86,14 @@ def get_maven_query_from_code(code: str) -> str:


def find_in_pom(path: Path) -> Callable[[str], FindInPomResponse]:
## Open XML file
## parse XML, find the dependency node if we have group and artifact we will return start_line and end_line for the full node
## If we don't have group and artifact, but we have dependencies, then we will find the start of the dependecies node. start_line and end_line will be the same. The start of the dependencies.
"""
- Open XML file
- Parse XML, find the dependency node if we have group and artifact we
will return start_line and end_line for the full node
- If we don't have group and artifact, but we have dependencies, then we
will find the start of the dependencies node. start_line and end_line
will be the same. The start of the dependencies.
"""
tag_to_kwargs = {
"{http://maven.apache.org/POM/4.0.0}artifactId": "artifactId",
"{http://maven.apache.org/POM/4.0.0}groupId": "groupId",
@@ -100,9 +105,10 @@ def f(code: str) -> FindInPomResponse:
root = tree.getroot()
deps = root.findall("*//{http://maven.apache.org/POM/4.0.0}dependency")
index = code.index("keywords")
# Remove 8 chars to get ride of keyword=
# Remove 8 chars to get rid of keyword=
code_string = code[index + 9 :].strip("(){}")
## We know when it is just an add operation, that the LLM gives us just the word dependencies
# We know when it is just an add operation, that the LLM gives us just
# the word dependencies
if "dependencies" in code_string:
return FindInPomResponse(override=False)

4 changes: 2 additions & 2 deletions kai/reactive_codeplanner/task_manager/task_manager.py
@@ -213,8 +213,8 @@ def get_next_task(
def initialize_priority_queue(self) -> None:
logger.info("Initializing task stacks.")

# When we re-initialze the priorty queue we need to start fresh
# Assume that something has changed in the project and we nned
# When we re-initialize the priority queue we need to start fresh
# Assume that something has changed in the project and we need
# to re-create the priority queue
self._stale_validated_files = []
new_tasks = self.run_validators()