Skip to content

Commit

Permalink
Added esbmc_output_type and source_code_format to the program. Along …
Browse files Browse the repository at this point in the history
…with tests
  • Loading branch information
Yiannis128 committed Mar 28, 2024
1 parent 7baacba commit 35403da
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 43 deletions.
64 changes: 32 additions & 32 deletions esbmc_ai/commands/fix_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from .. import config
from ..msg_bus import Signal
from ..loading_widget import create_loading_widget
from ..esbmc_util import esbmc_load_source_code
from ..solution_generator import SolutionGenerator
from ..esbmc_util import (
esbmc_get_error_type,
esbmc_load_source_code,
)
from ..solution_generator import SolutionGenerator, get_esbmc_output_formatted
from ..logging import printv, printvv

# TODO Remove built in messages and move them to config.
Expand All @@ -28,28 +31,14 @@ def __init__(self) -> None:
)
self.anim = create_loading_widget()

def _resolve_scenario(self, esbmc_output: str) -> str:
    """Extract the error type (the "scenario") from raw ESBMC output.

    The search starts at the last ``Violated property:`` marker. The first
    line after the marker holds the location of the violated property and
    the line after that holds the type of error, which is returned with
    surrounding whitespace removed.

    Args:
        esbmc_output: Text emitted by ESBMC.

    Returns:
        The stripped error-type line, or an empty string when the marker or
        the error-type line is not present.
    """
    marker: str = "Violated property:\n"
    marker_index: int = esbmc_output.rfind(marker)
    if marker_index == -1:
        # Without this guard, -1 + len(marker) would point at an arbitrary
        # offset inside the output and an unrelated line would be returned.
        return ""

    # Index 0 is the location line, index 1 is the error type. Splitting
    # (rather than searching for the closing "\n") keeps the last character
    # when the error-type line is not newline-terminated.
    lines_after_marker = esbmc_output[marker_index + len(marker) :].split("\n", 2)
    if len(lines_after_marker) < 2:
        return ""
    return lines_after_marker[1].strip()

@override
def execute(self, **kwargs: Any) -> Tuple[bool, str]:
file_name: str = kwargs["file_name"]
source_code: str = kwargs["source_code"]
esbmc_output: str = kwargs["esbmc_output"]

# Parse the esbmc output here and determine what "Scenario" to use.
scenario: str = self._resolve_scenario(esbmc_output)
scenario: str = esbmc_get_error_type(esbmc_output)

printv(f"Scenario: {scenario}")
printv(
Expand All @@ -58,33 +47,33 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
else "Using generic prompt..."
)

llm = config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
)

solution_generator = SolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
source_code=source_code,
esbmc_output=esbmc_output,
ai_model=config.ai_model,
llm=llm,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)

print()

max_retries: int = 10
max_retries: int = config.fix_code_max_attempts
for idx in range(max_retries):
# Get a response. Use while loop to account for if the message stack
# gets full, then need to compress and retry.
response: str = ""
llm_solution: str = ""
while True:
# Generate AI solution
self.anim.start("Generating Solution... Please Wait")
response, finish_reason = solution_generator.generate_solution()
llm_solution, finish_reason = solution_generator.generate_solution()
self.anim.stop()
if finish_reason == FinishReason.length:
self.anim.start("Compressing message stack... Please Wait")
Expand All @@ -96,7 +85,7 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
# Print verbose lvl 2
printvv("\nGeneration:")
printvv("-" * get_terminal_size().columns)
printvv(response)
printvv(llm_solution)
printvv("-" * get_terminal_size().columns)
printvv("")

Expand All @@ -105,22 +94,33 @@ def execute(self, **kwargs: Any) -> Tuple[bool, str]:
self.anim.start("Verifying with ESBMC... Please Wait")
exit_code, esbmc_output, esbmc_err_output = esbmc_load_source_code(
file_path=file_name,
source_code=str(response),
source_code=llm_solution,
esbmc_params=config.esbmc_params,
auto_clean=config.temp_auto_clean,
timeout=config.verifier_timeout,
)
self.anim.stop()

# TODO Move this process into Solution Generator since half (beginning) is done
# inside, and the other half is done here.
try:
esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=config.esbmc_output_type,
esbmc_output=esbmc_output,
)
except ValueError:
# Probably did not compile and so ESBMC source code is clang output.
pass

# Print verbose lvl 2
printvv("-" * get_terminal_size().columns)
printvv(esbmc_output)
printvv(esbmc_err_output)
printvv("-" * get_terminal_size().columns)

if exit_code == 0:
self.on_solution_signal.emit(response)
return False, response
self.on_solution_signal.emit(llm_solution)
return False, llm_solution

# Failure case
print(f"Failure {idx+1}/{max_retries}: Retrying...")
Expand Down
105 changes: 94 additions & 11 deletions esbmc_ai/solution_generator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,57 @@
# Author: Yiannis Charalambous 2023

from typing import Optional
from typing_extensions import override
from langchain.base_language import BaseLanguageModel
from langchain.schema import BaseMessage

from esbmc_ai.chat_response import ChatResponse, FinishReason
from esbmc_ai.config import DynamicAIModelAgent
from esbmc_ai.config import ChatPromptSettings, DynamicAIModelAgent
from esbmc_ai.frontend.solution import apply_line_patch

from .ai_models import AIModel
from .base_chat_interface import BaseChatInterface
from esbmc_ai.esbmc_util import (
esbmc_get_counter_example,
esbmc_get_violated_property,
get_source_code_err_line_idx,
)


def get_source_code_formatted(
source_code_format: str, source_code: str, esbmc_output: str
) -> str:
match source_code_format:
case "single":
line: Optional[int] = get_source_code_err_line_idx(esbmc_output)
assert line, "error line not found in esbmc output"
# ESBMC reports errors starting from 1. To get the correct line, we need to use 0 based
# indexing.
return source_code.splitlines(True)[line]
case "full":
return source_code
case _:
raise ValueError(
f"Not a valid format for source code: {source_code_format}"
)


def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str:
    """Reduce ESBMC output to the part selected by ``esbmc_output_type``.

    Args:
        esbmc_output_type: ``"vp"`` for the violated property, ``"ce"`` for
            the counterexample, or ``"full"`` for the output as-is.
        esbmc_output: Text emitted by ESBMC.

    Returns:
        The requested portion of the output.

    Raises:
        ValueError: If the requested portion cannot be extracted (the
            extractor returned nothing), or if ``esbmc_output_type`` is not
            one of the known types.
    """
    if esbmc_output_type == "vp":
        violated_property = esbmc_get_violated_property(esbmc_output)
        if violated_property:
            return violated_property
        raise ValueError("Not found violated property.")

    if esbmc_output_type == "ce":
        counterexample = esbmc_get_counter_example(esbmc_output)
        if counterexample:
            return counterexample
        raise ValueError("Not found counterexample.")

    if esbmc_output_type == "full":
        return esbmc_output

    raise ValueError(f"Not a valid ESBMC output type: {esbmc_output_type}")


class SolutionGenerator(BaseChatInterface):
Expand All @@ -20,17 +63,34 @@ def __init__(
esbmc_output: str,
ai_model: AIModel,
scenario: str = "",
source_code_format: str = "full",
esbmc_output_type: str = "full",
) -> None:
# Convert to chat prompt
chat_prompt: ChatPromptSettings = DynamicAIModelAgent.to_chat_prompt_settings(
ai_model_agent=ai_model_agent, scenario=scenario
)

super().__init__(
ai_model_agent=DynamicAIModelAgent.to_chat_prompt_settings(
ai_model_agent=ai_model_agent, scenario=scenario
),
ai_model_agent=chat_prompt,
ai_model=ai_model,
llm=llm,
)
self.initial_prompt = ai_model_agent.initial_prompt
self.source_code = source_code
self.esbmc_output = esbmc_output

self.esbmc_output_type: str = esbmc_output_type
self.esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=self.esbmc_output_type,
esbmc_output=esbmc_output,
)

self.source_code_format: str = source_code_format
self.source_code_raw: str = source_code
self.source_code = get_source_code_formatted(
source_code_format=self.source_code_format,
source_code=self.source_code_raw,
esbmc_output=self.esbmc_output,
)

self.set_template_value("source_code", self.source_code)
self.set_template_value("esbmc_output", self.esbmc_output)
Expand All @@ -46,20 +106,43 @@ def get_code_from_solution(cls, solution: str) -> str:
will generate text and formatting despite being told not to."""
try:
code_start: int = solution.index("```") + 3
assert code_start != -1

# Remove up until the new line, because usually there's a language
# specification after the 3 ticks ```c...
code_start = solution.index("\n", code_start)
code_end: int = len(solution) - 3 - solution[::-1].index("```")
code_start = solution.index("\n", code_start) + 1
assert code_start != -1

code_end: int = solution[::-1].index("```")
assert code_start != -1

# -4 = 3 backticks and also the \n before the backticks.
code_end: int = len(solution) - 4 - code_end
assert code_start <= code_end

solution = solution[code_start:code_end]
except ValueError:
except (ValueError, AssertionError):
pass
finally:
return solution
return solution

def generate_solution(self) -> tuple[str, FinishReason]:
    """Ask the LLM for a fix and return it as full source code.

    Sends the initial prompt, passes the reply through
    ``get_code_from_solution`` and, when the LLM was only shown a single
    line of source, patches that line back into the raw source code.

    Returns:
        A tuple of the solution source code and the finish reason reported
        with the LLM response.

    Raises:
        AssertionError: If the format is ``"single"`` and the error line
            cannot be found in the stored ESBMC output.
    """
    response: ChatResponse = self.send_message(self.initial_prompt)
    solution: str = SolutionGenerator.get_code_from_solution(
        str(response.message.content)
    )

    # If source code passed to LLM is formatted then we need to recombine to
    # full source code before giving to ESBMC
    if self.source_code_format == "single":
        err_line: Optional[int] = get_source_code_err_line_idx(self.esbmc_output)
        # Explicit None check: a line index of 0 is falsy but is presumably
        # a valid position (first source line), so it must not trip the
        # assertion.
        assert (
            err_line is not None
        ), "fix code command: error line could not be found to apply brutal patch replacement"
        solution = apply_line_patch(
            self.source_code_raw, solution, err_line, err_line
        )

    return solution, response.finish_reason
22 changes: 22 additions & 0 deletions tests/test_solution_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Author: Yiannis Charalambous

from esbmc_ai.solution_generator import SolutionGenerator


def test_get_code_from_solution():
    """Check code extraction from LLM replies that contain fenced blocks."""
    # (raw LLM reply, expected extracted code)
    cases = [
        # Fence with a language tag.
        ("This is a code block:\n\n```c\naaa\n```", "aaa"),
        # Fence without a language tag.
        ("This is a code block:\n\n```\nabc\n```", "abc"),
        # This reply is expected to come back unchanged.
        (
            "This is a code block:```abc\n```",
            "This is a code block:```abc\n```",
        ),
    ]
    for raw_reply, expected_code in cases:
        assert SolutionGenerator.get_code_from_solution(raw_reply) == expected_code

0 comments on commit 35403da

Please sign in to comment.