diff --git a/example/run_demo.py b/example/run_demo.py
index 84d83df9..1ab59d9b 100755
--- a/example/run_demo.py
+++ b/example/run_demo.py
@@ -7,7 +7,7 @@ import time
 
 from io import BufferedReader, BufferedWriter
 from pathlib import Path
-from typing import Generator, cast
+from typing import Any, Generator, cast
 
 from opentelemetry import trace
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -20,7 +20,7 @@ sys.path.append("../../")
 
 from kai.analyzer_types import ExtendedIncident, Report
 from kai.jsonrpc.core import JsonRpcServer
-from kai.jsonrpc.models import JsonRpcError, JsonRpcResponse
+from kai.jsonrpc.models import JsonRpcError, JsonRpcId, JsonRpcResponse
 from kai.jsonrpc.streams import LspStyleStream
 from kai.logging.logging import get_logger, init_logging_from_log_config
 from kai.rpc_server.server import (
@@ -94,6 +94,15 @@ def initialize_rpc_server(
 
     app = KaiRpcApplication()
 
+    @app.add_notify(method="my_progress")
+    def blah(
+        app: KaiRpcApplication,
+        server: JsonRpcServer,
+        id: JsonRpcId,
+        params: dict[str, Any],
+    ) -> None:
+        log.info(f"Received my_progress: {params}")
+
     rpc_server = JsonRpcServer(
         json_rpc_stream=LspStyleStream(
             cast(BufferedReader, rpc_subprocess.stdout),
@@ -179,6 +188,7 @@ def process_file(
         max_priority=0,
         max_depth=0,
         max_iterations=len(incidents),
+        chat_token=str("123e4567-e89b-12d3-a456-426614174000"),
     )
 
     KAI_LOG.debug(f"Request is: {params.model_dump()}")
diff --git a/kai/jsonrpc/core.py b/kai/jsonrpc/core.py
index ac230c18..19bb69be 100644
--- a/kai/jsonrpc/core.py
+++ b/kai/jsonrpc/core.py
@@ -104,11 +104,6 @@ def add(
         if method is None:
             raise ValueError("Method name must be provided")
 
-        if kind == "request":
-            callbacks = self.request_callbacks
-        else:
-            callbacks = self.notify_callbacks
-
         def decorator(
             func: JsonRpcCallable,
         ) -> JsonRpcCallback:
@@ -117,7 +112,13 @@ def decorator(
                 kind=kind,
                 method=method,
             )
-            callbacks[method] = callback
+
+            if kind == "request":
+                self.request_callbacks[method] = callback
+            else:
+                self.notify_callbacks[method] = callback
+
+            log.error(f"Added {kind} callback: {method}")
 
             return callback
diff --git a/kai/jsonrpc/streams.py b/kai/jsonrpc/streams.py
index 92d70974..ddf41d0c 100644
--- a/kai/jsonrpc/streams.py
+++ b/kai/jsonrpc/streams.py
@@ -68,7 +68,9 @@ def recv(self) -> JsonRpcError | JsonRpcRequest | JsonRpcResponse | None: ...
 
 def dump_json_no_infinite_recursion(msg: JsonRpcRequest | JsonRpcResponse) -> str:
     if not isinstance(msg, JsonRpcRequest) or msg.method != "logMessage":
-        return msg.model_dump_json()
+        # exclude_none = True because `None` serializes as `null`, which is not
+        # the same thing as `undefined` in JS
+        return msg.model_dump_json(exclude_none=True)
     else:
         log_msg = msg.model_copy()
         if log_msg.params is None:
@@ -80,7 +82,7 @@ def dump_json_no_infinite_recursion(msg: JsonRpcRequest | JsonRpcResponse) -> st
         if hasattr(log_msg.params, "message"):
             log_msg.params.message = ""
 
-        return log_msg.model_dump_json()
+        return log_msg.model_dump_json(exclude_none=True)
 
 
 class LspStyleStream(JsonRpcStream):
@@ -94,7 +96,7 @@ class LspStyleStream(JsonRpcStream):
     TYPE_HEADER = "Content-Type: "
 
     def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None:
-        json_str = msg.model_dump_json()
+        json_str = msg.model_dump_json(exclude_none=True)
         json_req = f"Content-Length: {len(json_str.encode('utf-8'))}\r\n\r\n{json_str}"
 
         log.log(TRACE, "Sending request: %s", dump_json_no_infinite_recursion(msg))
@@ -198,7 +200,7 @@ def __init__(
         self.log = log
 
     def send(self, msg: JsonRpcRequest | JsonRpcResponse) -> None:
-        json_req = f"{msg.model_dump_json()}\n"
+        json_req = f"{msg.model_dump_json(exclude_none=True)}\n"
 
         log.log(TRACE, "Sending request: %s", dump_json_no_infinite_recursion(msg))
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index 5f3f7546..c32bca8e 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -5,11 +5,7 @@ import os
 from typing import Any, Optional
 
-from genai import Client, Credentials
-from genai.extensions.langchain.chat_llm import LangChainChatInterface
-from genai.schema import DecodingMethod
 from langchain_aws import ChatBedrock
-from langchain_community.chat_models import ChatOllama
 from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_core.language_models.base import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -17,6 +13,7 @@ from langchain_core.runnables import RunnableConfig
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from pydantic.v1.utils import deep_update
 
@@ -73,42 +70,6 @@ def __init__(
             model_args = deep_update(defaults, config.args)
             model_id = model_args["model"]
 
-            case "ChatIBMGenAI":
-                model_class = LangChainChatInterface
-                if get_env_bool("KAI__DEMO_MODE", False):
-                    api_key = os.getenv("GENAI_KEY", "dummy_value")
-                    api_endpoint = os.getenv("GENAI_API", "")
-                    credentials = Credentials(
-                        api_key=api_key, api_endpoint=api_endpoint
-                    )
-                else:
-                    credentials = Credentials.from_env()
-                defaults = {
-                    "client": Client(credentials=credentials),
-                    "model_id": "ibm-mistralai/mixtral-8x7b-instruct-v01-q",
-                    "parameters": {
-                        "decoding_method": DecodingMethod.SAMPLE,
-                        # NOTE: probably have to do some more clever stuff regarding
-                        # config. max_new_tokens and such varies between models
-                        "max_new_tokens": 4096,
-                        "min_new_tokens": 10,
-                        "temperature": 0.05,
-                        "top_k": 20,
-                        "top_p": 0.9,
-                        "return_options": {"input_text": False, "input_tokens": True},
-                    },
-                    "moderations": {
-                        # Threshold is set to very low level to flag everything
-                        # (testing purposes) or set to True to enable HAP with
-                        # default settings
-                        "hap": {"input": True, "output": False, "threshold": 0.01}
-                    },
-                    "streaming": True,
-                }
-
-                model_args = deep_update(defaults, config.args)
-                model_id = model_args["model_id"]
-
             case "ChatBedrock":
                 model_class = ChatBedrock
 
diff --git a/kai/rpc_server/server.py b/kai/rpc_server/server.py
index e47fa489..c60ce66b 100644
--- a/kai/rpc_server/server.py
+++ b/kai/rpc_server/server.py
@@ -457,6 +457,8 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None
 
+    chat_token: str
+
 
 @app.add_request(method="getCodeplanAgentSolution")
 @tracer.start_as_current_span("get_codeplan_solution")
@@ -466,6 +468,21 @@ def get_codeplan_agent_solution(
     id: JsonRpcId,
     params: GetCodeplanAgentSolutionParams,
 ) -> None:
+    def simple_chat_message(msg: str) -> None:
+        app.log.info("simple_chat_message!")
+        server.send_notification(
+            method="my_progress",
+            params={
+                "chatToken": params.chat_token,
+                "kind": "SimpleChatMessage",
+                "value": {
+                    "message": msg,
+                },
+            },
+        )
+
+    simple_chat_message("Starting!")
+
     try:
         # create a set of AnalyzerRuleViolations
         # seed the task manager with these violations
@@ -500,31 +517,36 @@ def get_codeplan_agent_solution(
             if platform.system() == "Windows":
                 uri_path = uri_path.removeprefix("/")
 
-            seed_tasks.append(
-                class_to_use(
-                    file=str(Path(uri_path).absolute()),
-                    line=incident.line_number,
-                    column=-1,  # Not contained within report?
-                    message=incident.message,
-                    priority=0,
-                    incident=Incident(**incident.model_dump()),
-                    violation=Violation(
-                        id=incident.violation_name or "",
-                        description=incident.violation_description or "",
-                        category=incident.violation_category,
-                        labels=incident.violation_labels,
-                    ),
-                    ruleset=RuleSet(
-                        name=incident.ruleset_name,
-                        description=incident.ruleset_description or "",
-                    ),
-                )
+            seed_task = class_to_use(
+                file=str(Path(uri_path).absolute()),
+                line=incident.line_number,
+                column=-1,  # Not contained within report?
+                message=incident.message,
+                priority=0,
+                incident=Incident(**incident.model_dump()),
+                violation=Violation(
+                    id=incident.violation_name or "",
+                    description=incident.violation_description or "",
+                    category=incident.violation_category,
+                    labels=incident.violation_labels,
+                ),
+                ruleset=RuleSet(
+                    name=incident.ruleset_name,
+                    description=incident.ruleset_description or "",
+                ),
             )
+
+            seed_tasks.append(seed_task)
+
         app.task_manager.set_seed_tasks(*seed_tasks)
         app.log.info(
-            f"starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
+            f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
         )
+        simple_chat_message(
+            f"Starting processing with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
+        )
+
         next_task_fn = scoped_task_fn(
             params.max_iterations, app.task_manager.get_next_task
         )
@@ -545,8 +567,18 @@ class OverallResult(TypedDict):
         # get the ignored tasks set
         initial_ignored_tasks = app.task_manager.ignored_tasks
 
+        simple_chat_message("Running validators...")
+
         for task in next_task_fn(params.max_priority, params.max_depth):
             app.log.debug(f"Executing task {task.__class__.__name__}: {task}")
+            if hasattr(task, "message"):
+                simple_chat_message(
+                    f"Executing task {task.__class__.__name__} ({task.message}), from: {task.oldest_ancestor().__class__.__name__}."
+                )
+            else:
+                simple_chat_message(
+                    f"Executing task {task.__class__.__name__}, from: {task.oldest_ancestor().__class__.__name__}."
+                )
 
             # get the solved tasks set
             pre_task_solved_tasks = app.task_manager.processed_tasks
@@ -555,7 +587,9 @@ class OverallResult(TypedDict):
 
             result = app.task_manager.execute_task(task)
 
-            app.log.debug(f"Task {task.__class__.__name__} result: {result}")
+            app.log.debug(f"Task {task.__class__.__name__}, result: {result}")
+            # simple_chat_message(f"Got result! Encountered errors: {result.encountered_errors}. Modified files: {result.modified_files}.")
+            simple_chat_message("Finished task!")
 
             app.task_manager.supply_result(result)
@@ -613,6 +647,8 @@ class OverallResult(TypedDict):
             app.log.debug(f"QUEUE_STATE: IGNORED_TASKS: {task}")
         app.log.debug("QUEUE_STATE: IGNORED_TASKS: END")
 
+        simple_chat_message("Running validators...")
+
         # after we have completed all the tasks, we should show what has been accomplished for this particular solution
         app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: SUCCESSFUL TASKS: START")
         for task in app.task_manager.processed_tasks - initial_solved_tasks:
@@ -622,11 +658,14 @@ class OverallResult(TypedDict):
         for task in set(app.task_manager.ignored_tasks) - set(initial_ignored_tasks):
             app.log.debug(f"QUEUE_STATE_SEED_TASKS: SUCCESSFUL_TASKS: {task}")
         app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: IGNORED_TASKS: END")
+
         diff = app.rcm.snapshot.diff(agent_solution_snapshot)
         overall_result["diff"] = diff[1] + diff[2]
         app.rcm.reset(agent_solution_snapshot)
 
+        simple_chat_message("Finished!")
+
         server.send_response(
             id=id,
             result=dict(overall_result),
diff --git a/pyproject.toml b/pyproject.toml
index e9cc9cef..c6cd154b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,28 +20,28 @@ dependencies = [
     "aiohttp==3.8.6; python_version < '3.12'",
     "aiohttp==3.9.3; python_version >= '3.12'",
     "gitpython==3.1.43",
-    "pydantic==2.8.2",
+    "pydantic==2.10.6",
     "pydantic-settings==2.4.0",
     "requests==2.32.3",
     "pygments==2.18.0",
     "python-dateutil==2.8.2",
-    "ibm-generative-ai==2.2.0",
     "Jinja2==3.1.4",
-    "langchain==0.3.1",
+    "langchain==0.3.17",
     "langchain-community==0.3.1",
     "langchain-openai==0.3.2",
-    "langchain-google-genai==2.0.4",
-    "langchain-aws==0.2.1",
+    "langchain-ollama==0.2.3",
+    "langchain-google-genai==2.0.9",
+    "langchain-aws==0.2.11",
     "langchain-experimental==0.3.2",
     "langchain-deepseek-official==0.1.0",
     "gunicorn==22.0.0",
     "tree-sitter==0.22.3",
     "tree-sitter-java==0.21.0",
-    "sequoia-diff==0.0.8",
+    "sequoia-diff>=0.0.9",
     "python-dotenv==1.0.1",
     "pyyaml==6.0.1",
     "lxml==5.3.0",
-    "boto3==1.34.157", # Allows Amazon Bedrock to work
+    "boto3==1.36.9", # Allows Amazon Bedrock to work
     "pylspclient==0.1.2", # used for talking to RPC clients over stdin/stdout
     "opentelemetry-sdk",
    "opentelemetry-api",
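Note (outside the patch): a minimal sketch of the "my_progress" notification that the new simple_chat_message() helper emits, and of why the streams now serialize with exclude_none=True. The JsonRpcMessage model below is a stand-in for illustration only, not the actual class in kai.jsonrpc.models; the jsonrpc/id fields are assumed from the JSON-RPC 2.0 convention.

    # Stand-in model -- illustrative only, not kai.jsonrpc.models.JsonRpcRequest
    from typing import Any, Optional

    from pydantic import BaseModel


    class JsonRpcMessage(BaseModel):
        jsonrpc: str = "2.0"
        id: Optional[int] = None  # notifications carry no id
        method: str
        params: Optional[dict[str, Any]] = None


    msg = JsonRpcMessage(
        method="my_progress",
        params={
            "chatToken": "123e4567-e89b-12d3-a456-426614174000",
            "kind": "SimpleChatMessage",
            "value": {"message": "Starting!"},
        },
    )

    # exclude_none=True omits the unset "id" entirely instead of emitting "id": null,
    # so a JS client sees the field as undefined rather than null.
    print(msg.model_dump_json(exclude_none=True))

With a plain model_dump_json(), the same message would carry "id": null, which is the null-vs-undefined mismatch the new comment in kai/jsonrpc/streams.py calls out.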