diff --git a/kai/reactive_codeplanner/agent/dependency_agent/api.py b/kai/reactive_codeplanner/agent/dependency_agent/api.py
index 44ece76f..cae03cc2 100644
--- a/kai/reactive_codeplanner/agent/dependency_agent/api.py
+++ b/kai/reactive_codeplanner/agent/dependency_agent/api.py
@@ -17,7 +17,9 @@ class FQDNResponse:
version: str
def to_llm_message(self) -> HumanMessage:
- return HumanMessage(f"the result is {json.dumps(self.__dict__)}")
+ return HumanMessage(
+ f"The result for FQDN search is {json.dumps(self.__dict__)}"
+ )
def to_xml_element(self) -> ET._Element:
parent = ET.Element(MAVEN_DEPENDENCY_XML_KEY)
diff --git a/kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py b/kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py
index 83f12168..6787ec8a 100644
--- a/kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py
+++ b/kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py
@@ -3,7 +3,7 @@
from typing import Any, Callable, Optional, TypedDict, Union
from langchain.prompts.chat import HumanMessagePromptTemplate
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from kai.llm_interfacing.model_provider import ModelProvider
from kai.logging.logging import get_logger
@@ -131,7 +131,7 @@ class MavenDependencyAgent(Agent):
Final Answer:
Updated the guava to the commons-collections4 dependency
- """
+"""
)
inst_msg_template = HumanMessagePromptTemplate.from_template(
@@ -172,10 +172,10 @@ def __init__(
self,
model_provider: ModelProvider,
project_base: Path,
- retries: int = 1,
+ retries: int = 3,
) -> None:
self._model_provider = model_provider
- self._retries = retries
+ self._max_retries = retries
self.child_agent = FQDNDependencySelectorAgent(model_provider=model_provider)
self.agent_methods.update({"find_in_pom._run": find_in_pom(project_base)})
@@ -199,35 +199,49 @@ def execute(self, ask: AgentRequest) -> AgentResult:
# no result. In this case we will want the sub agent to try and give us additional information.
# Today, if we don't have the FQDN then we are going to skip updating for now.
- while fix_gen_attempts < self._retries:
+ all_actions: list[_action] = []
+ while fix_gen_attempts < self._max_retries:
fix_gen_attempts += 1
fix_gen_response = self._model_provider.invoke(msg)
llm_response = self.parse_llm_response(fix_gen_response.content)
- # Break out of the while loop, if we don't have a final answer then we need to retry
- if llm_response is None or not llm_response.final_answer:
+
+ # if we don't have a final answer, we need to retry, otherwise break
+ if llm_response is not None and llm_response.final_answer:
+ all_actions.extend(llm_response.actions)
break
- # We do not believe that we should not continue now we have to continue after running the code that is asked to be run.
- # The only exception to this rule, is when we actually update the file, that should be handled by the caller.
- # This happens sometimes that the LLM will stop and wait for more information.
+ # we have to keep the chat going until we get a final answer
+ msg.append(AIMessage(content=fix_gen_response.content))
+
+ if llm_response is None or not llm_response.actions:
+ msg.append(
+ HumanMessage(
+ content="Please provide a complete response until at least one action to perform."
+ )
+ )
+ continue
+ all_actions.extend(llm_response.actions)
+ tool_outputs: list[str] = []
for action in llm_response.actions:
for method_name, method in self.agent_methods.items():
if method_name in action.code:
if callable(method):
method_out = method(action.code)
- to_llm_message = getattr(method_out, "to_llm_message", None)
- if callable(to_llm_message):
- msg.append(method_out.to_llm_message())
+ to_llm_message: Optional[Callable[[], HumanMessage]] = (
+ getattr(method_out, "to_llm_message", None)
+ )
+ if to_llm_message is not None and callable(to_llm_message):
+ tool_outputs.append(method_out.to_llm_message().content)
- self._retries += 1
+ msg.append(HumanMessage(content="\n".join(tool_outputs)))
if llm_response is None or fix_gen_response is None:
return AgentResult()
if not maven_search:
- for a in llm_response.actions:
+ for a in all_actions:
if "search_fqdn.run" in a.code:
logger.debug("running search for FQDN")
_search_fqdn: Callable[
@@ -254,7 +268,7 @@ def execute(self, ask: AgentRequest) -> AgentResult:
maven_search = result
if not find_pom_lines:
- for a in llm_response.actions:
+ for a in all_actions:
if "find_in_pom._run" in a.code:
logger.debug("running find in pom")
_find_in_pom: Optional[Callable[[str], FindInPomResponse]] = (
@@ -303,7 +317,7 @@ def parse_llm_response(
parts = line.split(":")
if len(parts) > 1:
- match parts[0]:
+ match parts[0].strip():
case "Thought":
s = ":".join(parts[1:])
if code_block or observation_str:
@@ -328,9 +342,13 @@ def parse_llm_response(
in_observation = True
continue
case "Final Answer":
- actions.append(
- _action(code_block, thought_str, observation_str)
- )
+ if code_block:
+ actions.append(
+ _action(code_block, thought_str, observation_str)
+ )
+ code_block = ""
+ thought_str = ""
+ observation_str = ""
in_final_answer = True
in_code = False
in_thought = False
@@ -357,4 +375,8 @@ def parse_llm_response(
observation_str = "\n".join([observation_str, line]).strip()
else:
observation_str = line.strip()
+
+ if code_block and thought_str and observation_str:
+ actions.append(_action(code_block, thought_str, observation_str))
+
return _llm_response(actions, final_answer)
diff --git a/kai/reactive_codeplanner/task_runner/dependency/task_runner.py b/kai/reactive_codeplanner/task_runner/dependency/task_runner.py
index 7926461e..ed18d277 100644
--- a/kai/reactive_codeplanner/task_runner/dependency/task_runner.py
+++ b/kai/reactive_codeplanner/task_runner/dependency/task_runner.py
@@ -73,7 +73,7 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
)
return TaskResult(encountered_errors=[], modified_files=[])
- if not maven_dep_response.fqdn_response or not maven_dep_response.find_in_pom:
+ if not maven_dep_response.fqdn_response:
logger.info(
"we got a final answer, but it must have skipped steps in the LLM, we need to review the LLM call resposne %r",
maven_dep_response,
@@ -84,8 +84,6 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
"we are now updating the pom based %s", maven_dep_response.final_answer
)
pom = os.path.join(os.path.join(rcm.project_root, "pom.xml"))
- # Needed to remove ns0:
- ET.register_namespace("", "http://maven.apache.org/POM/4.0.0")
tree = ET.parse(pom) # trunk-ignore(bandit/B320)
if tree is None:
return TaskResult(modified_files=[], encountered_errors=[])
@@ -100,14 +98,15 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
## We always need to add the new dep
deps.append(maven_dep_response.fqdn_response.to_xml_element())
- if maven_dep_response.find_in_pom.override:
- ## we know we need to remove this dep
- for dep in deps:
- if maven_dep_response.find_in_pom.match_dep(dep):
- logger.debug("found dep %r and removing", dep)
- deps.remove(dep)
+ if maven_dep_response.find_in_pom is not None:
+ if maven_dep_response.find_in_pom.override:
+ ## we know we need to remove this dep
+ for dep in deps:
+ if maven_dep_response.find_in_pom.match_dep(dep):
+ logger.debug("found dep %r and removing", dep)
+ deps.remove(dep)
- tree.write(pom, "utf-8", pretty_print=True)
+ tree.write(file=pom, encoding="utf-8", pretty_print=True)
rcm.commit(
f"DependencyTaskRunner changed file {str(pom)}",
None,
diff --git a/tests/test_data/test_dependency_task_runner/Order.java b/tests/test_data/test_dependency_task_runner/Order.java
new file mode 100644
index 00000000..7fc6ff22
--- /dev/null
+++ b/tests/test_data/test_dependency_task_runner/Order.java
@@ -0,0 +1,33 @@
+
+package com.redhat.coolstore.model;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+import jakarta.persistence.CascadeType;
+import jakarta.persistence.Column;
+import jakarta.persistence.Entity;
+import jakarta.persistence.FetchType;
+import jakarta.persistence.GeneratedValue;
+import jakarta.persistence.GenerationType;
+import jakarta.persistence.Id;
+import jakarta.persistence.JoinColumn;
+import jakarta.persistence.OneToMany;
+import jakarta.persistence.SequenceGenerator;
+import jakarta.persistence.Table;
+
+@Entity
+@Table(name = "ORDERS")
+public class Order implements Serializable {
+
+ private static final long serialVersionUID = -1L;
+
+ @Id
+ @GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "order_seq")
+ @SequenceGenerator(name = "order_seq", sequenceName = "order_seq")
+ private long orderId;
+
+ // ... (rest of the code remains the same)
+
+}
diff --git a/tests/test_data/test_dependency_task_runner/expected_pom.xml b/tests/test_data/test_dependency_task_runner/expected_pom.xml
new file mode 100644
index 00000000..1eabd481
--- /dev/null
+++ b/tests/test_data/test_dependency_task_runner/expected_pom.xml
@@ -0,0 +1,65 @@
+
+ 4.0.0
+ com.redhat.coolstore
+ monolith
+ 1.0.0-SNAPSHOT
+ war
+ coolstore-monolith
+
+ UTF-8
+ yyyyMMdd'T'HHmmss
+ UTF-8
+ true
+
+
+
+ javax
+ javaee-web-api
+ 7.0
+ provided
+
+
+ javax
+ javaee-api
+ 7.0
+ provided
+
+
+ org.jboss.spec.javax.jms
+ jboss-jms-api_2.0_spec
+ 2.0.0.Final
+
+
+ org.flywaydb
+ flyway-core
+ 4.1.2
+
+
+ org.jboss.spec.javax.rmi
+ jboss-rmi-api_1.0_spec
+ 1.0.2.Final
+
+ jakarta.persistence-apijakarta.persistence3.2.0
+
+ ROOT
+
+
+ maven-compiler-plugin
+ 3.0
+
+ ${project.encoding}
+ 1.8
+ 1.8
+
+
+
+ org.apache.maven.plugins
+ maven-war-plugin
+ 3.2.0
+
+
+
+
+
+
+
diff --git a/tests/test_data/test_dependency_task_runner/pom.xml b/tests/test_data/test_dependency_task_runner/pom.xml
new file mode 100644
index 00000000..0e760b82
--- /dev/null
+++ b/tests/test_data/test_dependency_task_runner/pom.xml
@@ -0,0 +1,68 @@
+
+
+ 4.0.0
+ com.redhat.coolstore
+ monolith
+ 1.0.0-SNAPSHOT
+ war
+ coolstore-monolith
+
+ UTF-8
+ yyyyMMdd'T'HHmmss
+ UTF-8
+ true
+
+
+
+ javax
+ javaee-web-api
+ 7.0
+ provided
+
+
+ javax
+ javaee-api
+ 7.0
+ provided
+
+
+ org.jboss.spec.javax.jms
+ jboss-jms-api_2.0_spec
+ 2.0.0.Final
+
+
+ org.flywaydb
+ flyway-core
+ 4.1.2
+
+
+ org.jboss.spec.javax.rmi
+ jboss-rmi-api_1.0_spec
+ 1.0.2.Final
+
+
+
+ ROOT
+
+
+ maven-compiler-plugin
+ 3.0
+
+ ${project.encoding}
+ 1.8
+ 1.8
+
+
+
+ org.apache.maven.plugins
+ maven-war-plugin
+ 3.2.0
+
+
+
+
+
+
+
diff --git a/tests/test_dependency_task_runner.py b/tests/test_dependency_task_runner.py
new file mode 100644
index 00000000..282d3e6b
--- /dev/null
+++ b/tests/test_dependency_task_runner.py
@@ -0,0 +1,75 @@
+import unittest
+from pathlib import Path
+
+from kai.kai_config import KaiConfigModels
+from kai.llm_interfacing.model_provider import ModelProvider
+from kai.reactive_codeplanner.agent.dependency_agent.dependency_agent import (
+ MavenDependencyAgent,
+)
+from kai.reactive_codeplanner.task_runner.compiler.maven_validator import (
+ PackageDoesNotExistError,
+)
+from kai.reactive_codeplanner.task_runner.dependency.task_runner import (
+ DependencyTaskRunner,
+)
+from kai.reactive_codeplanner.vfs.git_vfs import RepoContextManager
+
+
+class TestDependencyTaskRunner(unittest.TestCase):
+
+ def _task_runner(
+ self, project_base: Path, response_variant: int = 0
+ ) -> DependencyTaskRunner:
+ responses = [
+ [
+ '\n\n\n\n\n\nThought: The error message indicates that the Maven compiler is unable to find the jakarta.persistence package. This package is typically provided by the Jakarta Persistence API (JPA) dependency. I need to find the fully qualified domain name (FQDN) of the JPA dependency to add it to the pom.xml file.\n\nAction: \n```python\nresult = search_fqdn.run(artifact_id="jakarta.persistence-api", group_id="jakarta.persistence")\n```\nObservation: We now have the FQDN for the Jakarta Persistence API dependency.',
+ '\n\n\n\n\n\nThought: Now that I have the FQDN for the Jakarta Persistence API dependency, I need to find the location in the pom.xml file where I can add this dependency.\n\nAction: \n```python\nstart_line, end_line = find_in_pom._run(relative_file_path="pom.xml", keywords=[""])\n```\nObservation: We now have the start and end line of the dependencies block in the pom.xml file.\n\nUser: The result for the location is (12, 13) \n\nAssistant: \n\n\n\n\n\nThought: Now that I have the start and end line of the dependencies block in the pom.xml file, I can add the Jakarta Persistence API dependency.\n\nAction: \n```python\nxml = f"{result.group_id}{result.artifact_id}{result.version}"\nresult = editor._run(relative_file_path="pom.xml", start_line=12, end_line=13, patch=xml)\nprint(result)\n```\nObservation: The pom.xml file is now updated with the Jakarta Persistence API dependency added.\n\nFinal Answer:\nAdded the Jakarta Persistence API dependency to the pom.xml file. Steps taken:\n1. Searched for the FQDN of the Jakarta Persistence API dependency.\n2. Found the location of the dependencies block in the pom.xml file.\n3. Added the Jakarta Persistence API dependency to the pom.xml file.',
+ ],
+ ]
+ return DependencyTaskRunner(
+ MavenDependencyAgent(
+ model_provider=ModelProvider(
+ config=KaiConfigModels(
+ args={
+ "responses": responses[response_variant],
+ "sleep": None,
+ },
+ provider="FakeListChatModel",
+ )
+ ),
+ project_base=project_base,
+ )
+ )
+
+ def test_package_does_not_exist_task(self) -> None:
+ project_base = Path(
+ ".", "tests", "test_data", "test_dependency_task_runner"
+ ).absolute()
+
+ task = PackageDoesNotExistError(
+ priority=1,
+ parse_lines="'[ERROR] ./test_data/test_dependency_agent/Order.java:[8,27] package jakarta.persistence does not exist'",
+ missing_package="jakarta.persistence",
+ message="package jakarta.persistence does not exist",
+ file=str(project_base / "Order.java"),
+ line=8,
+ max_retries=3,
+ column=27,
+ )
+
+ runner = self._task_runner(project_base=project_base, response_variant=0)
+
+ rcm = RepoContextManager(project_root=project_base)
+ result = runner.execute_task(rcm=rcm, task=task)
+
+ self.assertEqual(len(result.modified_files), 1)
+
+ if result.modified_files:
+ modified_pom = result.modified_files[0]
+ with open(modified_pom) as f:
+ actual_pom_contents = f.read()
+ with open(project_base / "expected_pom.xml") as f:
+ expected_pom_contents = f.read()
+ self.assertEqual(actual_pom_contents, expected_pom_contents)
+
+ rcm.reset_to_first()