Fix Code Mode: Add message history customization #126

Merged: 13 commits, Apr 20, 2024
Changes from all commits
1 change: 1 addition & 0 deletions config.json
@@ -72,6 +72,7 @@
"generate_solution": {
"max_attempts": 5,
"temperature": 1.3,
"message_history": "normal",
"scenarios": {
"division by zero": {
"system": [
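The new message_history key selects how the Fix Code Command presents conversation state to the model; the loader added in config.py below accepts normal, latest_only, and reverse. A minimal sketch of the surrounding block, with the chat_modes nesting inferred from that loader and all other keys elided:

    "chat_modes": {
        "generate_solution": {
            "max_attempts": 5,
            "temperature": 1.3,
            "message_history": "latest_only",
            "scenarios": { ... }
        }
    }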
82 changes: 53 additions & 29 deletions esbmc_ai/commands/fix_code_command.py
@@ -3,9 +3,10 @@
import sys
from typing import Any, Tuple
from typing_extensions import override
-from langchain.schema import AIMessage, HumanMessage

from esbmc_ai.chat_response import FinishReason
+from esbmc_ai.latest_state_solution_generator import LatestStateSolutionGenerator
+from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator

from .chat_command import ChatCommand
from .. import config
@@ -18,8 +19,6 @@
from ..solution_generator import (
    ESBMCTimedOutException,
    SolutionGenerator,
-    SourceCodeParseError,
-    get_esbmc_output_formatted,
)
from ..logging import print_horizontal_line, printv, printvv

@@ -61,21 +60,58 @@ def print_raw_conversation() -> None:
else "Using generic prompt..."
)

+        match config.fix_code_message_history:
+            case "normal":
+                solution_generator = SolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case "latest_only":
+                solution_generator = LatestStateSolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case "reverse":
+                solution_generator = ReverseOrderSolutionGenerator(
+                    ai_model_agent=config.chat_prompt_generator_mode,
+                    ai_model=config.ai_model,
+                    llm=config.ai_model.create_llm(
+                        api_keys=config.api_keys,
+                        temperature=config.chat_prompt_generator_mode.temperature,
+                        requests_max_tries=config.requests_max_tries,
+                        requests_timeout=config.requests_timeout,
+                    ),
+                    scenario=scenario,
+                    source_code_format=config.source_code_format,
+                    esbmc_output_type=config.esbmc_output_type,
+                )
+            case _:
+                raise NotImplementedError(
+                    f"error: {config.fix_code_message_history} has not been implemented in the Fix Code Command"
+                )

        try:
-            solution_generator = SolutionGenerator(
-                ai_model_agent=config.chat_prompt_generator_mode,
-                ai_model=config.ai_model,
-                llm=config.ai_model.create_llm(
-                    api_keys=config.api_keys,
-                    temperature=config.chat_prompt_generator_mode.temperature,
-                    requests_max_tries=config.requests_max_tries,
-                    requests_timeout=config.requests_timeout,
-                ),
-                scenario=scenario,
-                source_code_format=config.source_code_format,
-                esbmc_output_type=config.esbmc_output_type,
-            )
+            solution_generator.update_state(
+                source_code=source_code,
+                esbmc_output=esbmc_output,
+            )
        except ESBMCTimedOutException:
            print("error: ESBMC has timed out...")
@@ -93,9 +129,7 @@ def print_raw_conversation() -> None:
            llm_solution, finish_reason = solution_generator.generate_solution()
            self.anim.stop()
            if finish_reason == FinishReason.length:
-                self.anim.start("Compressing message stack... Please Wait")
                solution_generator.compress_message_stack()
-                self.anim.stop()
            else:
                source_code = llm_solution
                break
@@ -135,26 +169,16 @@ def print_raw_conversation() -> None:

return False, source_code

-                # TODO Move this process into Solution Generator since have (beginning) is done
-                # inside, and the other half is done here.
-                # Get formatted ESBMC output
-                try:
-                    esbmc_output = get_esbmc_output_formatted(
-                        esbmc_output_type=config.esbmc_output_type,
-                        esbmc_output=esbmc_output,
-                    )
-                except SourceCodeParseError:
-                    pass
-                # Update state
-                solution_generator.update_state(source_code, esbmc_output)
            except ESBMCTimedOutException:
                print("error: ESBMC has timed out...")
                sys.exit(1)

            # Failure case
            print(f"ESBMC-AI Notice: Failure {idx+1}/{max_retries}: Retrying...")

+            # Update state
+            solution_generator.update_state(source_code, esbmc_output)

if config.raw_conversation:
print_raw_conversation()

13 changes: 13 additions & 0 deletions esbmc_ai/config.py
@@ -44,6 +44,7 @@
source_code_format: str = "full"

fix_code_max_attempts: int = 5
+fix_code_message_history: str = ""

requests_max_tries: int = 5
requests_timeout: float = 60
@@ -57,6 +58,7 @@
cfg_path: str


+# TODO Get rid of this class as soon as ConfigTool with the pyautoconfig
class AIAgentConversation(NamedTuple):
"""Immutable class describing the conversation definition for an AI agent. The
class represents the system messages of the AI agent defined and contains a load
@@ -384,6 +386,17 @@ def load_config(file_path: str) -> None:
f"ESBMC output type in the config is not valid: {esbmc_output_type}"
)

+    global fix_code_message_history
+    fix_code_message_history, _ = _load_config_value(
+        config_file=config_file["chat_modes"]["generate_solution"],
+        name="message_history",
+    )
+
+    if fix_code_message_history not in ["normal", "latest_only", "reverse"]:
+        raise ValueError(
+            f"error: fix code mode message history not valid: {fix_code_message_history}"
+        )

global requests_max_tries
requests_max_tries = int(
_load_config_real_number(
27 changes: 27 additions & 0 deletions esbmc_ai/latest_state_solution_generator.py
@@ -0,0 +1,27 @@
# Author: Yiannis Charalambous

from typing_extensions import override
from langchain_core.messages import BaseMessage
from esbmc_ai.solution_generator import SolutionGenerator
from esbmc_ai.chat_response import FinishReason

# TODO Test me


class LatestStateSolutionGenerator(SolutionGenerator):
"""SolutionGenerator that only shows the latest source code and verifier
output state."""

@override
def generate_solution(self) -> tuple[str, FinishReason]:
# Backup message stack and clear before sending base message. We want
# to keep the message stack intact because we will print it with
# print_raw_conversation.
messages: list[BaseMessage] = self.messages
self.messages: list[BaseMessage] = []
solution, finish_reason = super().generate_solution()
# Append last messages to the messages stack
messages.extend(self.messages)
# Restore
self.messages = messages
return solution, finish_reason
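The backup/clear/restore above keeps the full conversation available for print_raw_conversation while sending only the freshest state to the model. A standalone illustration of the same pattern with toy stand-ins (hypothetical class names, not the project's API; assumes langchain_core is installed, as the new module imports it):

    from langchain_core.messages import BaseMessage, HumanMessage

    class Chat:
        """Toy stand-in for BaseChatInterface: records what each call would send."""

        def __init__(self) -> None:
            self.messages: list[BaseMessage] = []
            self.sent: list[list[BaseMessage]] = []

        def generate(self) -> None:
            # Pretend to call the LLM with the current stack and receive a reply.
            self.messages.append(HumanMessage(content="latest state"))
            self.sent.append(list(self.messages))

    class LatestOnlyChat(Chat):
        def generate(self) -> None:
            backup = self.messages        # keep full history for printing later
            self.messages = []            # send only the freshest state
            super().generate()
            backup.extend(self.messages)  # fold the new turns back into history
            self.messages = backup

    chat = LatestOnlyChat()
    chat.messages.append(HumanMessage(content="old turn"))
    chat.generate()
    assert len(chat.sent[0]) == 1   # the model saw only the fresh state
    assert len(chat.messages) == 2  # the printable history kept the old turn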
34 changes: 34 additions & 0 deletions esbmc_ai/reverse_order_solution_generator.py
@@ -0,0 +1,34 @@
# Author: Yiannis Charalambous

from langchain.schema import BaseMessage, HumanMessage
from typing_extensions import override, Optional
from esbmc_ai.solution_generator import (
SolutionGenerator,
get_source_code_formatted,
get_source_code_err_line_idx,
get_clang_err_line_index,
apply_line_patch,
)
from esbmc_ai.chat_response import FinishReason, ChatResponse

# TODO Test me


class ReverseOrderSolutionGenerator(SolutionGenerator):
"""SolutionGenerator that shows the source code and verifier output state in
reverse order."""

@override
def send_message(self, message: Optional[str] = None) -> ChatResponse:
# Reverse the messages
messages: list[BaseMessage] = self.messages.copy()
self.messages.reverse()

response: ChatResponse = super().send_message(message)

# Add to the reversed message the new message received by the LLM.
messages.append(self.messages[-1])
# Restore
self.messages = messages

return response
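A similar standalone illustration of the reverse-order behavior (toy code, not the project's API): the stack is reversed before sending so the most recent state appears first, then the original order is restored with the new reply appended.

    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

    def send_reversed(messages: list[BaseMessage]) -> list[BaseMessage]:
        """Send the stack newest-first, then restore the original order with
        the reply appended (mirrors the send_message override above)."""
        original = messages.copy()
        messages.reverse()                         # what the model would see
        messages.append(AIMessage(content="fix"))  # stand-in for the LLM reply
        original.append(messages[-1])              # keep the reply, drop the reversal
        return original

    stack: list[BaseMessage] = [
        HumanMessage(content="turn 1"),
        HumanMessage(content="turn 2"),
    ]
    stack = send_reversed(stack)
    assert [m.content for m in stack] == ["turn 1", "turn 2", "fix"]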
79 changes: 44 additions & 35 deletions esbmc_ai/solution_generator.py
@@ -82,50 +82,43 @@ def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str
class SolutionGenerator(BaseChatInterface):
    def __init__(
        self,
-        ai_model_agent: DynamicAIModelAgent,
+        ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
        llm: BaseLanguageModel,
-        source_code: str,
-        esbmc_output: str,
        ai_model: AIModel,
        scenario: str = "",
        source_code_format: str = "full",
        esbmc_output_type: str = "full",
    ) -> None:
-        # Convert to chat prompt
-        chat_prompt: ChatPromptSettings = DynamicAIModelAgent.to_chat_prompt_settings(
-            ai_model_agent=ai_model_agent, scenario=scenario
-        )
+        """Initializes the solution generator. This ModelChat provides Dynamic
+        Prompting. Will get the correct scenario from the DynamicAIModelAgent
+        supplied and create a ChatPrompt."""
+
+        chat_prompt: ChatPromptSettings = ai_model_agent
+        if isinstance(ai_model_agent, DynamicAIModelAgent):
+            # Convert to chat prompt
+            chat_prompt = DynamicAIModelAgent.to_chat_prompt_settings(
+                ai_model_agent=ai_model_agent, scenario=scenario
+            )

        super().__init__(
            ai_model_agent=chat_prompt,
            ai_model=ai_model,
            llm=llm,
        )

-        self.initial_prompt = ai_model_agent.initial_prompt
-
        self.esbmc_output_type: str = esbmc_output_type
        self.source_code_format: str = source_code_format
-        self.source_code_raw: str = source_code
-        # Used for resetting state.
-        self._original_source_code: str = source_code
-
-        # Format ESBMC output
-        try:
-            self.esbmc_output = get_esbmc_output_formatted(
-                esbmc_output_type=self.esbmc_output_type,
-                esbmc_output=esbmc_output,
-            )
-        except SourceCodeParseError:
-            # When clang output is displayed, show it entirely as it doesn't get very
-            # big.
-            self.esbmc_output = esbmc_output
+        self.source_code_raw: Optional[str] = None
+        self.source_code_formatted: Optional[str] = None
+        self.esbmc_output: Optional[str] = None

    @override
    def compress_message_stack(self) -> None:
        # Resets the conversation - cannot summarize code
+        # If generate_solution is called after this point, it will start new
+        # with the currently set state.
        self.messages: list[BaseMessage] = []
-        self.source_code_raw = self._original_source_code

    @classmethod
    def get_code_from_solution(cls, solution: str) -> str:
@@ -153,27 +146,43 @@ def get_code_from_solution(cls, solution: str) -> str:
            pass
        return solution

-    def update_state(
-        self, source_code: Optional[str] = None, esbmc_output: Optional[str] = None
-    ) -> None:
-        if source_code:
-            self.source_code_raw = source_code
-        if esbmc_output:
-            self.esbmc_output = esbmc_output
+    def update_state(self, source_code: str, esbmc_output: str) -> None:
+        """Updates the latest state of the code and ESBMC output. This should be
+        called before generate_solution."""
+        self.source_code_raw = source_code

-    def generate_solution(self) -> tuple[str, FinishReason]:
-        self.push_to_message_stack(HumanMessage(content=self.initial_prompt))
+        # Format ESBMC output
+        try:
+            self.esbmc_output = get_esbmc_output_formatted(
+                esbmc_output_type=self.esbmc_output_type,
+                esbmc_output=esbmc_output,
+            )
+        except SourceCodeParseError:
+            # When clang output is displayed, show it entirely as it doesn't get very
+            # big.
+            self.esbmc_output = esbmc_output

        # Format source code
-        source_code_formatted: str = get_source_code_formatted(
+        self.source_code_formatted = get_source_code_formatted(
            source_code_format=self.source_code_format,
-            source_code=self.source_code_raw,
+            source_code=source_code,
            esbmc_output=self.esbmc_output,
        )

+    def generate_solution(self) -> tuple[str, FinishReason]:
+        assert (
+            self.source_code_raw is not None
+            and self.source_code_formatted is not None
+            and self.esbmc_output is not None
+        ), "Call update_state before calling generate_solution."
+
+        self.push_to_message_stack(
+            HumanMessage(content=self.ai_model_agent.initial_prompt)
+        )
+
        # Apply template substitution to message stack
        self.apply_template_value(
-            source_code=source_code_formatted,
+            source_code=self.source_code_formatted,
            esbmc_output=self.esbmc_output,
        )

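With state moved out of the constructor, the call contract becomes: construct once, then call update_state before each generate_solution. A hedged usage sketch with names taken from this diff; the setup of llm, source_code, esbmc_output, and the surrounding config is elided, so this is illustrative rather than runnable as-is:

    # Sketch of the new call order (names from this PR; LLM/config setup elided).
    solution_generator = SolutionGenerator(
        ai_model_agent=config.chat_prompt_generator_mode,
        ai_model=config.ai_model,
        llm=llm,
        scenario=scenario,
        source_code_format=config.source_code_format,
        esbmc_output_type=config.esbmc_output_type,
    )
    solution_generator.update_state(source_code=source_code, esbmc_output=esbmc_output)
    solution, finish_reason = solution_generator.generate_solution()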