
✨ Added preliminary noisy notifications #596

Merged
merged 11 commits on Feb 4, 2025
17 changes: 14 additions & 3 deletions example/run_demo.py
@@ -7,7 +7,7 @@
import time
from io import BufferedReader, BufferedWriter
from pathlib import Path
from typing import Generator, cast
from typing import Any, Generator, cast

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -20,9 +20,9 @@
sys.path.append("../../")
from kai.analyzer_types import ExtendedIncident, Report
from kai.jsonrpc.core import JsonRpcServer
from kai.jsonrpc.models import JsonRpcError, JsonRpcResponse
from kai.jsonrpc.models import JsonRpcError, JsonRpcId, JsonRpcResponse
from kai.jsonrpc.streams import LspStyleStream
from kai.logging.logging import get_logger, init_logging_from_log_config
from kai.logging.logging import TRACE, get_logger, init_logging_from_log_config
from kai.rpc_server.server import (
GetCodeplanAgentSolutionParams,
KaiRpcApplication,
@@ -93,6 +93,16 @@ def initialize_rpc_server(
log.info(rpc_subprocess.args)

app = KaiRpcApplication()
log.setLevel(TRACE)
Contributor: is this needed?

Contributor Author: Nope, this was just for testing. Removing.


@app.add_notify(method="my_progress")
def blah(
app: KaiRpcApplication,
server: JsonRpcServer,
id: JsonRpcId,
params: dict[str, Any],
) -> None:
log.info(f"Received my_progress: {params}")

rpc_server = JsonRpcServer(
json_rpc_stream=LspStyleStream(
@@ -179,6 +189,7 @@ def process_file(
max_priority=0,
max_depth=0,
max_iterations=len(incidents),
chat_token=str("123e4567-e89b-12d3-a456-426614174000"),
)

KAI_LOG.debug(f"Request is: {params.model_dump()}")
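For reference, the params dict delivered to the my_progress handler above mirrors what the server-side helper simple_chat_message (in kai/rpc_server/server.py, further down in this diff) sends. A minimal sketch of the payload, using this demo's hard-coded token; the message value is illustrative:

# Shape of the params dict the "my_progress" notification handler receives.
params = {
    "chatToken": "123e4567-e89b-12d3-a456-426614174000",
    "kind": "SimpleChatMessage",
    "value": {"message": "Starting!"},
}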
13 changes: 7 additions & 6 deletions kai/jsonrpc/core.py
@@ -104,11 +104,6 @@ def add(
if method is None:
raise ValueError("Method name must be provided")

if kind == "request":
callbacks = self.request_callbacks
else:
callbacks = self.notify_callbacks

def decorator(
func: JsonRpcCallable,
) -> JsonRpcCallback:
@@ -117,7 +112,13 @@ def decorator(
kind=kind,
method=method,
)
callbacks[method] = callback

if kind == "request":
self.request_callbacks[method] = callback
else:
self.notify_callbacks[method] = callback

log.error(f"Added {kind} callback: {method}")

return callback

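Since registration now happens inside the decorator, @app.add_notify handlers land in notify_callbacks and @app.add_request handlers in request_callbacks at decoration time. A minimal usage sketch (handler name and body are illustrative; the signature mirrors the my_progress handler in example/run_demo.py above):

from typing import Any

from kai.jsonrpc.core import JsonRpcServer
from kai.jsonrpc.models import JsonRpcId
from kai.rpc_server.server import KaiRpcApplication

app = KaiRpcApplication()

@app.add_notify(method="my_progress")  # stored in self.notify_callbacks
def on_progress(
    app: KaiRpcApplication,
    server: JsonRpcServer,
    id: JsonRpcId,
    params: dict[str, Any],
) -> None:
    print(params)  # placeholder handler body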
10 changes: 6 additions & 4 deletions kai/jsonrpc/streams.py
@@ -68,7 +68,9 @@ def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: ...

def dump_json_no_infinite_recursion(msg: JsonRpcRequest | JsonRpcResponse) -> str:
if not isinstance(msg, JsonRpcRequest) or msg.method != "logMessage":
return msg.model_dump_json()
# exclude_none = True because `None` serializes as `null`, which is not
# the same thing as `undefined` in JS
return msg.model_dump_json(exclude_none=True)
else:
log_msg = msg.model_copy()
if log_msg.params is None:
@@ -80,7 +82,7 @@ def dump_json_no_infinite_recursion(msg: JsonRpcRequest | JsonRpcResponse) -> str
if hasattr(log_msg.params, "message"):
log_msg.params.message = "<omitted>"

return log_msg.model_dump_json()
return log_msg.model_dump_json(exclude_none=True)
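The comment above is the crux of these serialization changes. A minimal sketch (not from this PR) of the difference exclude_none makes in pydantic v2:

from typing import Optional
from pydantic import BaseModel

class Msg(BaseModel):
    jsonrpc: str = "2.0"
    method: str
    id: Optional[int] = None  # notifications carry no id

m = Msg(method="my_progress")
print(m.model_dump_json())
# {"jsonrpc":"2.0","method":"my_progress","id":null}   JS client sees null
print(m.model_dump_json(exclude_none=True))
# {"jsonrpc":"2.0","method":"my_progress"}             JS client sees undefined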


class LspStyleStream(JsonRpcStream):
@@ -94,7 +96,7 @@ class LspStyleStream(JsonRpcStream):
TYPE_HEADER = "Content-Type: "

def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None:
json_str = msg.model_dump_json()
json_str = msg.model_dump_json(exclude_none=True)
json_req = f"Content-Length: {len(json_str.encode('utf-8'))}\r\n\r\n{json_str}"

log.log(TRACE, "Sending request: %s", dump_json_no_infinite_recursion(msg))
@@ -198,7 +200,7 @@ def __init__(
self.log = log

def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None:
json_req = f"{msg.model_dump_json()}\n"
json_req = f"{msg.model_dump_json(exclude_none=True)}\n"

log.log(TRACE, "Sending request: %s", dump_json_no_infinite_recursion(msg))

41 changes: 1 addition & 40 deletions kai/llm_interfacing/model_provider.py
@@ -5,18 +5,15 @@
import os
from typing import Any, Optional

from genai import Client, Credentials
from genai.extensions.langchain.chat_llm import LangChainChatInterface
from genai.schema import DecodingMethod
from langchain_aws import ChatBedrock
from langchain_community.chat_models import ChatOllama
from langchain_community.chat_models.fake import FakeListChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1.utils import deep_update

@@ -73,42 +70,6 @@ def __init__(
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case "ChatIBMGenAI":
model_class = LangChainChatInterface
if get_env_bool("KAI__DEMO_MODE", False):
api_key = os.getenv("GENAI_KEY", "dummy_value")
api_endpoint = os.getenv("GENAI_API", "")
credentials = Credentials(
api_key=api_key, api_endpoint=api_endpoint
)
else:
credentials = Credentials.from_env()
defaults = {
"client": Client(credentials=credentials),
"model_id": "ibm-mistralai/mixtral-8x7b-instruct-v01-q",
"parameters": {
"decoding_method": DecodingMethod.SAMPLE,
# NOTE: probably have to do some more clever stuff regarding
# config. max_new_tokens and such varies between models
"max_new_tokens": 4096,
"min_new_tokens": 10,
"temperature": 0.05,
"top_k": 20,
"top_p": 0.9,
"return_options": {"input_text": False, "input_tokens": True},
},
"moderations": {
# Threshold is set to very low level to flag everything
# (testing purposes) or set to True to enable HAP with
# default settings
"hap": {"input": True, "output": False, "threshold": 0.01}
},
"streaming": True,
}

model_args = deep_update(defaults, config.args)
model_id = model_args["model_id"]

case "ChatBedrock":
model_class = ChatBedrock

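With the IBM GenAI provider dropped, ChatOllama now comes from the dedicated langchain-ollama package (pinned below in pyproject.toml) instead of langchain_community. A minimal sketch of the swapped-in import; the model name is illustrative and assumes a locally pulled Ollama model:

from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3")
print(llm.invoke("Hello").content)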
81 changes: 60 additions & 21 deletions kai/rpc_server/server.py
@@ -457,6 +457,8 @@ class GetCodeplanAgentSolutionParams(BaseModel):
max_depth: Optional[int] = None
max_priority: Optional[int] = None

chat_token: str


@app.add_request(method="getCodeplanAgentSolution")
@tracer.start_as_current_span("get_codeplan_solution")
@@ -466,6 +468,21 @@ def get_codeplan_agent_solution(
id: JsonRpcId,
params: GetCodeplanAgentSolutionParams,
) -> None:
def simple_chat_message(msg: str) -> None:
app.log.info("simple_chat_message!")
server.send_notification(
method="my_progress",
params={
"chatToken": params.chat_token,
"kind": "SimpleChatMessage",
"value": {
"message": msg,
},
},
)

simple_chat_message("Starting!")

try:
# create a set of AnalyzerRuleViolations
# seed the task manager with these violations
@@ -500,31 +517,36 @@ def get_codeplan_agent_solution(
if platform.system() == "Windows":
uri_path = uri_path.removeprefix("/")

seed_tasks.append(
class_to_use(
file=str(Path(uri_path).absolute()),
line=incident.line_number,
column=-1, # Not contained within report?
message=incident.message,
priority=0,
incident=Incident(**incident.model_dump()),
violation=Violation(
id=incident.violation_name or "",
description=incident.violation_description or "",
category=incident.violation_category,
labels=incident.violation_labels,
),
ruleset=RuleSet(
name=incident.ruleset_name,
description=incident.ruleset_description or "",
),
)
seed_task = class_to_use(
file=str(Path(uri_path).absolute()),
line=incident.line_number,
column=-1, # Not contained within report?
message=incident.message,
priority=0,
incident=Incident(**incident.model_dump()),
violation=Violation(
id=incident.violation_name or "",
description=incident.violation_description or "",
category=incident.violation_category,
labels=incident.violation_labels,
),
ruleset=RuleSet(
name=incident.ruleset_name,
description=incident.ruleset_description or "",
),
)

seed_tasks.append(seed_task)

app.task_manager.set_seed_tasks(*seed_tasks)

app.log.info(
f"starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
)
simple_chat_message(
f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
)

next_task_fn = scoped_task_fn(
params.max_iterations, app.task_manager.get_next_task
)
@@ -545,8 +567,18 @@ class OverallResult(TypedDict):
# get the ignored tasks set
initial_ignored_tasks = app.task_manager.ignored_tasks

simple_chat_message("Running validators...")

for task in next_task_fn(params.max_priority, params.max_depth):
app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
if hasattr(task, "message"):
simple_chat_message(
f"Executing task {task.__class__.__name__} ({task.message}), from: {task.oldest_ancestor().__class__.__name__}."
)
else:
simple_chat_message(
f"Executing task {task.__class__.__name__}, from: {task.oldest_ancestor().__class__.__name__}."
)

# get the solved tasks set
pre_task_solved_tasks = app.task_manager.processed_tasks
@@ -555,7 +587,9 @@ class OverallResult(TypedDict):

result = app.task_manager.execute_task(task)

app.log.debug(f"Task {task.__class__.__name__} result: {result}")
app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
# simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
simple_chat_message("Finished task!")

app.task_manager.supply_result(result)

@@ -613,6 +647,8 @@ class OverallResult(TypedDict):
app.log.debug(f"QUEUE_STATE: IGNORED_TASKS: {task}")
app.log.debug("QUEUE_STATE: IGNORED_TASKS: END")

simple_chat_message("Running validators...")

# after we have completed all the tasks, we should show what has been accomplished for this particular solution
app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: SUCCESSFUL TASKS: START")
for task in app.task_manager.processed_tasks - initial_solved_tasks:
@@ -622,11 +658,14 @@ class OverallResult(TypedDict):
for task in set(app.task_manager.ignored_tasks) - set(initial_ignored_tasks):
app.log.debug(f"QUEUE_STATE_SEED_TASKS: SUCCESSFUL_TASKS: {task}")
app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: IGNORED_TASKS: END")

diff = app.rcm.snapshot.diff(agent_solution_snapshot)
overall_result["diff"] = diff[1] + diff[2]

app.rcm.reset(agent_solution_snapshot)

simple_chat_message("Finished!")

server.send_response(
id=id,
result=dict(overall_result),
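Putting the pieces together: each simple_chat_message call becomes a JSON-RPC notification (no id field, omitted thanks to exclude_none=True), framed by LspStyleStream with a Content-Length header. A rough sketch of one such message on the wire, using the demo's hard-coded token:

body = (
    '{"jsonrpc":"2.0","method":"my_progress",'
    '"params":{"chatToken":"123e4567-e89b-12d3-a456-426614174000",'
    '"kind":"SimpleChatMessage","value":{"message":"Starting!"}}}'
)
frame = f"Content-Length: {len(body.encode('utf-8'))}\r\n\r\n{body}"
# Matches the framing in LspStyleStream.send above.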
14 changes: 7 additions & 7 deletions pyproject.toml
@@ -20,28 +20,28 @@ dependencies = [
"aiohttp==3.8.6; python_version < '3.12'",
"aiohttp==3.9.3; python_version >= '3.12'",
"gitpython==3.1.43",
"pydantic==2.8.2",
"pydantic==2.10.6",
"pydantic-settings==2.4.0",
"requests==2.32.3",
"pygments==2.18.0",
"python-dateutil==2.8.2",
"ibm-generative-ai==2.2.0",
"Jinja2==3.1.4",
"langchain==0.3.1",
"langchain==0.3.17",
"langchain-community==0.3.1",
"langchain-openai==0.3.2",
"langchain-google-genai==2.0.4",
"langchain-aws==0.2.1",
"langchain-ollama==0.2.3",
"langchain-google-genai==2.0.9",
"langchain-aws==0.2.11",
"langchain-experimental==0.3.2",
"langchain-deepseek-official==0.1.0",
"gunicorn==22.0.0",
"tree-sitter==0.22.3",
"tree-sitter-java==0.21.0",
"sequoia-diff==0.0.8",
"sequoia-diff>=0.0.9",
"python-dotenv==1.0.1",
"pyyaml==6.0.1",
"lxml==5.3.0",
"boto3==1.34.157", # Allows Amazon Bedrock to work
"boto3==1.36.9", # Allows Amazon Bedrock to work
"pylspclient==0.1.2", # used for talking to RPC clients over stdin/stdout
"opentelemetry-sdk",
"opentelemetry-api",