Skip to content

Commit

Permalink
Merge pull request #83 from Yiannis128/v0.3.1
Browse files Browse the repository at this point in the history
V0.3.1
  • Loading branch information
Yiannis128 authored Sep 26, 2023
2 parents 30a4cfd + 97e815b commit b048ebd
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 20 deletions.
2 changes: 2 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"ai_model": "gpt-3.5-turbo-16k",
"ai_custom": {},
"esbmc_path": "./esbmc",
"allow_successful": true,
"esbmc_params": [
"--interval-analysis",
"--goto-unwind",
Expand All @@ -20,6 +21,7 @@
"consecutive_prompt_delay": 20,
"temp_auto_clean": false,
"temp_file_dir": "./temp",
"loading_hints": true,
"chat_modes": {
"user_chat": {
"temperature": 1.0,
Expand Down
24 changes: 18 additions & 6 deletions esbmc_ai_lib/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def init_commands() -> None:
fix_code_command.on_solution_signal.add_listener(chat.set_solution)
fix_code_command.on_solution_signal.add_listener(update_solution)

optimize_code_command.on_solution_signal.add_listener(chat.set_solution)
optimize_code_command.on_solution_signal.add_listener(update_solution)


def _run_command_mode(
command: ChatCommand,
Expand All @@ -174,14 +177,15 @@ def _run_command_mode(
sys.exit(1)
else:
print(solution)
# elif command == verify_code_command:
# raise NotImplementedError()
elif command == optimize_code_command:
optimize_code_command.execute(
error, solution = optimize_code_command.execute(
file_path=get_main_source_file_path(),
source_code=source_code,
function_names=args,
)

print(solution)
sys.exit(1 if error else 0)
else:
command.execute()
sys.exit(0)
Expand Down Expand Up @@ -302,12 +306,13 @@ def main() -> None:

# ESBMC will output 0 for verification success and 1 for verification
# failed, if anything else gets thrown, it's an ESBMC error.
if exit_code == 0:
if not config.allow_successful and exit_code == 0:
print("Success!")
print(esbmc_output)
sys.exit(0)
elif exit_code != 1:
elif exit_code != 0 and exit_code != 1:
print(f"ESBMC exit code: {exit_code}")
print(f"ESBMC Output:\n\n{esbmc_err_output}")
sys.exit(1)

# Command mode: Check if command is called and call it.
Expand Down Expand Up @@ -396,11 +401,18 @@ def main() -> None:
continue
elif command == optimize_code_command.command_name:
# Optimize Code command
optimize_code_command.execute(
error, solution = optimize_code_command.execute(
file_path=get_main_source_file_path(),
source_code=get_main_source_file().content,
function_names=command_args,
)

if error:
# Print error
print("\n" + solution + "\n")
else:
print(f"\nOptimizations Completed:\n```c\n{solution}```\n")

continue
else:
# Commands without parameters or returns are handled automatically.
Expand Down
5 changes: 4 additions & 1 deletion esbmc_ai_lib/commands/fix_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from os import get_terminal_size
from time import sleep
from typing import Tuple
from typing_extensions import override
from langchain.schema import AIMessage, HumanMessage

Expand Down Expand Up @@ -31,7 +32,9 @@ def __init__(self) -> None:
self.anim = create_loading_widget()

@override
def execute(self, file_name: str, source_code: str, esbmc_output: str):
def execute(
self, file_name: str, source_code: str, esbmc_output: str
) -> Tuple[bool, str]:
wait_time: int = int(config.consecutive_prompt_delay)
# Create time left animation to show how much time left between API calls
# This is done by creating a list of all the numbers to be displayed and
Expand Down
33 changes: 22 additions & 11 deletions esbmc_ai_lib/commands/optimize_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from os import get_terminal_size
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple
from typing_extensions import override
from string import Template
from random import randint
Expand All @@ -13,6 +13,7 @@
from esbmc_ai_lib.frontend.c_types import is_primitive_type
from esbmc_ai_lib.frontend.esbmc_code_generator import ESBMCCodeGenerator
from esbmc_ai_lib.esbmc_util import esbmc_load_source_code
from esbmc_ai_lib.msg_bus import Signal
from esbmc_ai_lib.solution_generator import SolutionGenerator
from .chat_command import ChatCommand
from .. import config
Expand All @@ -31,6 +32,7 @@ def __init__(self) -> None:
command_name="optimize-code",
help_message="(EXPERIMENTAL) Optimizes the code of a specific function or the entire file if a function is not specified. Usage: optimize-code [function_name]",
)
self.on_solution_signal: Signal = Signal()

def _get_functions_list(
self,
Expand Down Expand Up @@ -280,8 +282,19 @@ def get_function_from_collection(

@override
def execute(
self, file_path: str, source_code: str, function_names: list[str]
) -> None:
self,
file_path: str,
source_code: str,
function_names: list[str],
) -> Tuple[bool, str]:
"""Executes the Optimize Code command. The command takes the following inputs:
* file_path: The path of the source code file.
* source_code: The source code file contents.
* function_names: List of function names to optimize. Main is always excluded.
Returns a `Tuple[bool, str]` which is the flag if there was an error, and the
source code from the LLM.
"""
clang_ast: ast.ClangAST = ast.ClangAST(
file_path=file_path,
source_code=source_code,
Expand Down Expand Up @@ -323,7 +336,7 @@ def execute(
function_name=fn_name,
)

new_source_code: str = SolutionGenerator.get_code_from_solution(
optimized_source_code: str = SolutionGenerator.get_code_from_solution(
response.message.content
)

Expand All @@ -335,19 +348,17 @@ def execute(
# Check equivalence
equal: bool = self.check_function_pequivalence(
original_source_code=source_code,
new_source_code=new_source_code,
new_source_code=optimized_source_code,
function_name=fn_name,
)

if equal:
new_source_code = response.message.content
# If equal, then return with explanation.
new_source_code = optimized_source_code
break
elif attempt == max_retries - 1:
print("Failed all attempts...")
return
return True, "Failed all attempts..."
else:
print("Failed attempt", attempt)

print("\nOptimizations Completed:\n")
print(new_source_code)
print()
return False, new_source_code
32 changes: 32 additions & 0 deletions esbmc_ai_lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
consecutive_prompt_delay: float = 20.0
ai_model: AIModel = AIModels.GPT_3.value

loading_hints: bool = False
allow_successful: bool = False

cfg_path: str = "./config.json"


Expand Down Expand Up @@ -146,6 +149,21 @@ def _load_config_value(
return default, False


def _load_config_bool(
config_file: dict,
name: str,
default: bool = False,
) -> bool:
value, _ = _load_config_value(config_file, name, default)
if isinstance(value, bool):
return value
else:
raise TypeError(
f"Error: config invalid {name} value: {value} "
+ "Make sure it is a bool value."
)


def _load_config_real_number(
config_file: dict, name: str, default: object = None
) -> Union[int, float]:
Expand Down Expand Up @@ -197,6 +215,20 @@ def load_config(file_path: str) -> None:
temp_file_dir,
)

global allow_successful
allow_successful = _load_config_bool(
config_file,
"allow_successful",
False,
)

global loading_hints
loading_hints = _load_config_bool(
config_file,
"loading_hints",
True,
)

# Load the custom ai configs.
_load_custom_ai(config_file["ai_custom"])

Expand Down
10 changes: 8 additions & 2 deletions esbmc_ai_lib/loading_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from time import sleep
from itertools import cycle
from threading import Thread
from typing import Optional

from esbmc_ai_lib import config


class LoadingWidget(object):
done: bool = False
thread: Thread
thread: Optional[Thread]
loading_text: str
animation: list[str]
anim_speed: float
Expand Down Expand Up @@ -53,6 +56,8 @@ def _animate(self) -> None:
terminal.flush()

def start(self, text: str = "Please Wait") -> None:
if not config.loading_hints:
return
self.done = False
self.loading_text = text
self.thread = Thread(target=self._animate)
Expand All @@ -62,7 +67,8 @@ def start(self, text: str = "Please Wait") -> None:
def stop(self) -> None:
self.done = True
# Block until end.
self.thread.join()
if self.thread:
self.thread.join()


_widgets: list[LoadingWidget] = []
Expand Down
13 changes: 13 additions & 0 deletions esbmc_ai_lib/user_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def set_solution(self, source_code: str) -> None:
message=AIMessage(content="Understood"), protected=True
)

def set_optimized_solution(self, source_code: str) -> None:
self.solution = source_code
self.push_to_message_stack(
message=HumanMessage(
content=f"Here is the optimized code:\n\n{source_code}"
),
protected=True,
)

self.push_to_message_stack(
message=AIMessage(content="Understood"), protected=True
)

@override
def compress_message_stack(self) -> None:
"""Uses ConversationSummaryMemory from Langchain to summarize the conversation of all the non-protected
Expand Down
26 changes: 26 additions & 0 deletions samples/optimize-code/fact_01.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <stdio.h>

int factorial(int n)
{
if (n <= 0)
{
return 1;
}
else
{
int result = 1;
while (n > 0)
{
result *= n;
n--;
}
return result;
}
}

int main()
{
int num = 10;
printf("Factorial of %d is %d\n", num, factorial(num));
return 0;
}
44 changes: 44 additions & 0 deletions tests/regtest/test_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
source_code = """struct linear
{
int value;
};
typedef struct linear LinearTypeDef;
typedef struct
{
int x;
int y;
} Point;
Point a;
Point *b;
int c;
char *d;
typedef enum Types
{
ONE,
TWO,
THREE
} Typest;
enum Types e = ONE;
Typest f = TWO;
union Combines
{
int a;
int b;
int c;
};
typedef union Combines CombinesTypeDef;
enum extra { A, B, C};
typedef enum extra ExtraEnum;"""
# TODO

0 comments on commit b048ebd

Please sign in to comment.