Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Fix dependency agent back and forth chat for not-so-good models #490

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion kai/reactive_codeplanner/agent/dependency_agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class FQDNResponse:
version: str

def to_llm_message(self) -> HumanMessage:
return HumanMessage(f"the result is {json.dumps(self.__dict__)}")
return HumanMessage(
f"The result for FQDN search is {json.dumps(self.__dict__)}"
)

def to_xml_element(self) -> ET._Element:
parent = ET.Element(MAVEN_DEPENDENCY_XML_KEY)
Expand Down
62 changes: 42 additions & 20 deletions kai/reactive_codeplanner/agent/dependency_agent/dependency_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional, TypedDict, Union

from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain_core.messages import SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from kai.llm_interfacing.model_provider import ModelProvider
from kai.logging.logging import get_logger
Expand Down Expand Up @@ -131,7 +131,7 @@ class MavenDependencyAgent(Agent):

Final Answer:
Updated the guava to the commons-collections4 dependency
"""
"""
)

inst_msg_template = HumanMessagePromptTemplate.from_template(
Expand Down Expand Up @@ -172,10 +172,10 @@ def __init__(
self,
model_provider: ModelProvider,
project_base: Path,
retries: int = 1,
retries: int = 3,
) -> None:
self._model_provider = model_provider
self._retries = retries
self._max_retries = retries
self.child_agent = FQDNDependencySelectorAgent(model_provider=model_provider)
self.agent_methods.update({"find_in_pom._run": find_in_pom(project_base)})

Expand All @@ -199,35 +199,49 @@ def execute(self, ask: AgentRequest) -> AgentResult:
# no result. In this case we will want the sub agent to try and give us additional information.
# Today, if we don't have the FQDN then we are going to skip updating for now.

while fix_gen_attempts < self._retries:
all_actions: list[_action] = []
while fix_gen_attempts < self._max_retries:
fix_gen_attempts += 1

fix_gen_response = self._model_provider.invoke(msg)
llm_response = self.parse_llm_response(fix_gen_response.content)
# Break out of the while loop, if we don't have a final answer then we need to retry
if llm_response is None or not llm_response.final_answer:

# if we don't have a final answer, we need to retry, otherwise break
if llm_response is not None and llm_response.final_answer:
all_actions.extend(llm_response.actions)
break

# We do not believe that we should not continue now we have to continue after running the code that is asked to be run.
# The only exception to this rule, is when we actually update the file, that should be handled by the caller.
# This happens sometimes that the LLM will stop and wait for more information.
# we have to keep the chat going until we get a final answer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have some limit here, say 5-8 round trips? if it is more, then something is not correct, and we need to exit?

Copy link
Contributor Author

@pranavgaikwad pranavgaikwad Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set the limit to 3 but can increase

msg.append(AIMessage(content=fix_gen_response.content))

if llm_response is None or not llm_response.actions:
msg.append(
HumanMessage(
content="Please provide a complete response until at least one action to perform."
)
)
continue

all_actions.extend(llm_response.actions)
tool_outputs: list[str] = []
for action in llm_response.actions:
for method_name, method in self.agent_methods.items():
if method_name in action.code:
if callable(method):
method_out = method(action.code)
to_llm_message = getattr(method_out, "to_llm_message", None)
if callable(to_llm_message):
msg.append(method_out.to_llm_message())
to_llm_message: Optional[Callable[[], HumanMessage]] = (
getattr(method_out, "to_llm_message", None)
)
if to_llm_message is not None and callable(to_llm_message):
tool_outputs.append(method_out.to_llm_message().content)

self._retries += 1
msg.append(HumanMessage(content="\n".join(tool_outputs)))

if llm_response is None or fix_gen_response is None:
return AgentResult()

if not maven_search:
for a in llm_response.actions:
for a in all_actions:
if "search_fqdn.run" in a.code:
logger.debug("running search for FQDN")
_search_fqdn: Callable[
Expand All @@ -254,7 +268,7 @@ def execute(self, ask: AgentRequest) -> AgentResult:
maven_search = result

if not find_pom_lines:
for a in llm_response.actions:
for a in all_actions:
if "find_in_pom._run" in a.code:
logger.debug("running find in pom")
_find_in_pom: Optional[Callable[[str], FindInPomResponse]] = (
Expand Down Expand Up @@ -303,7 +317,7 @@ def parse_llm_response(
parts = line.split(":")

if len(parts) > 1:
match parts[0]:
match parts[0].strip():
case "Thought":
s = ":".join(parts[1:])
if code_block or observation_str:
Expand All @@ -328,9 +342,13 @@ def parse_llm_response(
in_observation = True
continue
case "Final Answer":
actions.append(
_action(code_block, thought_str, observation_str)
)
if code_block:
actions.append(
_action(code_block, thought_str, observation_str)
)
code_block = ""
thought_str = ""
observation_str = ""
in_final_answer = True
in_code = False
in_thought = False
Expand All @@ -357,4 +375,8 @@ def parse_llm_response(
observation_str = "\n".join([observation_str, line]).strip()
else:
observation_str = line.strip()

if code_block and thought_str and observation_str:
actions.append(_action(code_block, thought_str, observation_str))

return _llm_response(actions, final_answer)
19 changes: 9 additions & 10 deletions kai/reactive_codeplanner/task_runner/dependency/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
)
return TaskResult(encountered_errors=[], modified_files=[])

if not maven_dep_response.fqdn_response or not maven_dep_response.find_in_pom:
if not maven_dep_response.fqdn_response:
logger.info(
"we got a final answer, but it must have skipped steps in the LLM, we need to review the LLM call resposne %r",
maven_dep_response,
Expand All @@ -84,8 +84,6 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
"we are now updating the pom based %s", maven_dep_response.final_answer
)
pom = os.path.join(os.path.join(rcm.project_root, "pom.xml"))
# Needed to remove ns0:
ET.register_namespace("", "http://maven.apache.org/POM/4.0.0")
tree = ET.parse(pom) # trunk-ignore(bandit/B320)
if tree is None:
return TaskResult(modified_files=[], encountered_errors=[])
Expand All @@ -100,14 +98,15 @@ def execute_task(self, rcm: RepoContextManager, task: Task) -> TaskResult:
## We always need to add the new dep
deps.append(maven_dep_response.fqdn_response.to_xml_element())

if maven_dep_response.find_in_pom.override:
## we know we need to remove this dep
for dep in deps:
if maven_dep_response.find_in_pom.match_dep(dep):
logger.debug("found dep %r and removing", dep)
deps.remove(dep)
if maven_dep_response.find_in_pom is not None:
if maven_dep_response.find_in_pom.override:
## we know we need to remove this dep
for dep in deps:
if maven_dep_response.find_in_pom.match_dep(dep):
logger.debug("found dep %r and removing", dep)
deps.remove(dep)

tree.write(pom, "utf-8", pretty_print=True)
tree.write(file=pom, encoding="utf-8", pretty_print=True)
rcm.commit(
f"DependencyTaskRunner changed file {str(pom)}",
None,
Expand Down
33 changes: 33 additions & 0 deletions tests/test_data/test_dependency_task_runner/Order.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

package com.redhat.coolstore.model;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import jakarta.persistence.CascadeType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.FetchType;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.OneToMany;
import jakarta.persistence.SequenceGenerator;
import jakarta.persistence.Table;

@Entity
@Table(name = "ORDERS")
public class Order implements Serializable {

private static final long serialVersionUID = -1L;

@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "order_seq")
@SequenceGenerator(name = "order_seq", sequenceName = "order_seq")
private long orderId;

// ... (rest of the code remains the same)

}
65 changes: 65 additions & 0 deletions tests/test_data/test_dependency_task_runner/expected_pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.redhat.coolstore</groupId>
<artifactId>monolith</artifactId>
<version>1.0.0-SNAPSHOT</version>
<packaging>war</packaging>
<name>coolstore-monolith</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.build.timestamp.format>yyyyMMdd'T'HHmmss</maven.build.timestamp.format>
<project.encoding>UTF-8</project.encoding>
<maven.test.skip>true</maven.test.skip>
</properties>
<dependencies>
<dependency>
<groupId>javax</groupId>
<artifactId>javaee-web-api</artifactId>
<version>7.0</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>javax</groupId>
<artifactId>javaee-api</artifactId>
<version>7.0</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.jboss.spec.javax.jms</groupId>
<artifactId>jboss-jms-api_2.0_spec</artifactId>
<version>2.0.0.Final</version>
</dependency>
<dependency>
<groupId>org.flywaydb</groupId>
<artifactId>flyway-core</artifactId>
<version>4.1.2</version>
</dependency>
<dependency>
<groupId>org.jboss.spec.javax.rmi</groupId>
<artifactId>jboss-rmi-api_1.0_spec</artifactId>
<version>1.0.2.Final</version>
</dependency>
<dependency><artifactId>jakarta.persistence-api</artifactId><groupId>jakarta.persistence</groupId><version>3.2.0</version></dependency></dependencies>
<build>
<finalName>ROOT</finalName>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.0</version>
<configuration>
<encoding>${project.encoding}</encoding>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-war-plugin</artifactId>
<version>3.2.0</version>
</plugin>
</plugins>
</build>
<profiles>
<!-- TODO: Add OpenShift profile here -->
</profiles>
</project>
68 changes: 68 additions & 0 deletions tests/test_data/test_dependency_task_runner/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<?xml version="1.0" encoding="UTF-8"?>
<project
xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.redhat.coolstore</groupId>
<artifactId>monolith</artifactId>
<version>1.0.0-SNAPSHOT</version>
<packaging>war</packaging>
<name>coolstore-monolith</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.build.timestamp.format>yyyyMMdd'T'HHmmss</maven.build.timestamp.format>
<project.encoding>UTF-8</project.encoding>
<maven.test.skip>true</maven.test.skip>
</properties>
<dependencies>
<dependency>
<groupId>javax</groupId>
<artifactId>javaee-web-api</artifactId>
<version>7.0</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>javax</groupId>
<artifactId>javaee-api</artifactId>
<version>7.0</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.jboss.spec.javax.jms</groupId>
<artifactId>jboss-jms-api_2.0_spec</artifactId>
<version>2.0.0.Final</version>
</dependency>
<dependency>
<groupId>org.flywaydb</groupId>
<artifactId>flyway-core</artifactId>
<version>4.1.2</version>
</dependency>
<dependency>
<groupId>org.jboss.spec.javax.rmi</groupId>
<artifactId>jboss-rmi-api_1.0_spec</artifactId>
<version>1.0.2.Final</version>
</dependency>
</dependencies>
<build>
<finalName>ROOT</finalName>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.0</version>
<configuration>
<encoding>${project.encoding}</encoding>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-war-plugin</artifactId>
<version>3.2.0</version>
</plugin>
</plugins>
</build>
<profiles>
<!-- TODO: Add OpenShift profile here -->
</profiles>
</project>
Loading
Loading