diff --git a/esbmc_ai_lib/commands/optimize_code_command.py b/esbmc_ai_lib/commands/optimize_code_command.py index c533904..601e09e 100644 --- a/esbmc_ai_lib/commands/optimize_code_command.py +++ b/esbmc_ai_lib/commands/optimize_code_command.py @@ -3,14 +3,14 @@ import os import sys from os import get_terminal_size -from typing import Iterable, Optional, Tuple +from typing import Optional, Tuple from typing_extensions import override from string import Template from random import randint from esbmc_ai_lib.chat_response import json_to_base_messages from esbmc_ai_lib.frontend.ast_decl import Declaration, TypeDeclaration -from esbmc_ai_lib.frontend.c_types import is_primitive_type +from esbmc_ai_lib.frontend.c_types import is_primitive_type, get_base_type from esbmc_ai_lib.frontend.esbmc_code_generator import ESBMCCodeGenerator from esbmc_ai_lib.esbmc_util import esbmc_load_source_code from esbmc_ai_lib.msg_bus import Signal @@ -20,7 +20,7 @@ from ..base_chat_interface import ChatResponse from ..optimize_code import OptimizeCode from ..frontend import ast -from ..frontend.ast import FunctionDeclaration +from ..frontend.ast import ClangAST, FunctionDeclaration from ..logging import printvv @@ -34,6 +34,10 @@ def __init__(self) -> None: ) self.on_solution_signal: Signal = Signal() + @staticmethod + def _generate_param_name(d: Declaration) -> str: + return f"param_{d.type_name.replace(' ', '_').replace('*', 'ptr')}_{randint(a=0, b=99999)}" + def _get_functions_list( self, clang_ast: ast.ClangAST, @@ -58,6 +62,62 @@ def _get_functions_list( function_names = all_function_names.copy() return function_names + def _generate_primitive_type_variables( + self, + elements: list[Declaration], + code_gen: ESBMCCodeGenerator, + max_depth: int = 10, + _current_depth: int = -1, + ) -> tuple[list[str], list[str]]: + """Generates the initialization values for a list of Declarations. 
+ Starts by traversing the parameters of a function depth first, as soon + as a primitive type is reached, a variable is initialized using + __VERIFIER_nondet_X() as the value. + The names of the variables, along with the declaration statement is returned.""" + + ast: ClangAST = code_gen.ast + + names: list[str] = [] + statements: list[str] = [] + + _current_depth += 1 + + for arg in elements: + assert arg.cursor, f"Element {elements} has no valid cursor." + + if is_primitive_type(arg): + name: str = OptimizeCodeCommand._generate_param_name(arg) + statement = code_gen.statement_primitive_construct( + d=arg, + assign_to=name, + init=True, + ) + names.append(name) + statements.append(statement) + else: + arg_type: Optional[ + TypeDeclaration + ] = ast._get_type_declaration_from_cursor(arg.cursor) + assert arg_type, f"Assert failed: {arg}" + + if _current_depth < max_depth: + # Traverse tree + n, s = self._generate_primitive_type_variables( + elements=arg_type.elements, + code_gen=code_gen, + max_depth=max_depth, + _current_depth=_current_depth, + ) + + names.extend(n) + statements.extend(s) + else: + # Generate NULL + names.append("NULL") + statements.append("") + + return names, statements + def _build_comparison_script( self, old_ast: ast.ClangAST, @@ -110,52 +170,39 @@ def rename_append_declaration(ast: ast.ClangAST, append: str): code_gen_old: ESBMCCodeGenerator = ESBMCCodeGenerator(old_ast) code_gen_new: ESBMCCodeGenerator = ESBMCCodeGenerator(new_ast) - # This Callable is used to record the primitive types that the parameters of each - # function take. They will be injected into the main method and supplied to - # both old and new function calls. Since they have the same function, they will - # generate the same code. 
- def primitive_assignemnt_old(decl: Declaration) -> str: - name: str = ( - f"param_{decl.type_name.replace(' ', '')}_{randint(a=0, b=99999)}" - ) - statement: str = code_gen_old.statement_primitive_construct( - d=decl, - assign_to=name, - init=True, - ) - decl_statements.append(statement) - decl_statement_names.append(name) - - return name - - def primitive_assignemnt_new(_: Declaration) -> str: - return decl_statement_names.pop(0) - # Generate parameters using same function call of __VERIFIER_nondet_X for primitives. # The parameters will be inlined into the function. + + # Variable names of primitive values. + primitive_vars_names: list[str] = [] # Declaration statements for the variables - decl_statements: list[str] = [] - # Name of the variables - decl_statement_names: list[str] = [] - # Params for each function + statements: list[str] = [] + primitive_vars_names, statements = self._generate_primitive_type_variables( + elements=old_function.args, + code_gen=code_gen_old, + max_depth=config.ocm_init_max_depth, + ) + + # Function parameters. fn_params_old: list[str] = [] fn_params_new: list[str] = [] + + # Generate function parameter values. + param_idx: int = 0 for arg_old, arg_new in zip(old_function.args, new_function.args): # Convert to type assert arg_old.cursor and arg_new.cursor - if is_primitive_type(arg_old.type_name): - name: str = f"param_{arg_old.type_name.replace(' ', '')}_{randint(a=0, b=99999)}" - statement = code_gen_old.statement_primitive_construct( - d=arg_old, - assign_to=name, - init=True, - ) - fn_params_old.append(name) - fn_params_new.append(name) - - decl_statements.append(statement) + # If it's a primitive type, then no need to call the code generator, + # can simply add name of primitive value. 
+ if is_primitive_type(arg_old): + fn_params_old.append(primitive_vars_names[param_idx]) + fn_params_new.append(primitive_vars_names[param_idx]) + param_idx += 1 else: + if arg_old.is_pointer_type(): + pass + arg_old_type: Optional[ TypeDeclaration ] = old_ast._get_type_declaration_from_cursor(arg_old.cursor) @@ -165,20 +212,28 @@ def primitive_assignemnt_new(_: Declaration) -> str: assert ( arg_old_type and arg_new_type ), f"Assert failed: {arg_old} or {arg_new}" + # Create statements and save them. statement_old = code_gen_old.statement_type_construct( - d=arg_old_type, + d_type=arg_old_type, + init_type="ptr" if arg_old.is_pointer_type() else "value", init=False, - primitive_assignment_fn=primitive_assignemnt_old, + primitive_assignment_fn=lambda _: primitive_vars_names[param_idx], + max_depth=config.ocm_init_max_depth, ) fn_params_old.append(statement_old) + statement_new = code_gen_new.statement_type_construct( - d=arg_new_type, + d_type=arg_new_type, + init_type="ptr" if arg_new.is_pointer_type() else "value", init=False, - primitive_assignment_fn=primitive_assignemnt_new, + primitive_assignment_fn=lambda _: primitive_vars_names[param_idx], + max_depth=config.ocm_init_max_depth, ) fn_params_new.append(statement_new) + param_idx += 1 + # Generate the function call & arguments old_params_src: str = code_gen_old.statement_function_call( fn=old_function, @@ -193,10 +248,12 @@ def primitive_assignemnt_new(_: Declaration) -> str: init=True, ) + # TODO Dereference fn result if pointer. + script: str = template.substitute( function_old=old_ast.source_code, function_new=new_ast.source_code, - parameters_list="\n".join(decl_statements), + parameters_list="\n".join(statements), function_call_old=old_params_src, function_call_new=new_params_src, function_assert_old="old_fn_result", @@ -226,8 +283,6 @@ def check_function_pequivalence( ) # Get list of function types. 
- # TODO: Need to double check that set needs to be used instead of list here - # as the list may return duplicates. old_functions: list[FunctionDeclaration] = original_ast.get_fn_decl() new_functions: list[FunctionDeclaration] = new_ast.get_fn_decl() diff --git a/esbmc_ai_lib/frontend/esbmc_code_generator.py b/esbmc_ai_lib/frontend/esbmc_code_generator.py index ebb090c..f0a9424 100644 --- a/esbmc_ai_lib/frontend/esbmc_code_generator.py +++ b/esbmc_ai_lib/frontend/esbmc_code_generator.py @@ -1,12 +1,14 @@ # Author: Yiannis Charalambous -from typing import Callable, Optional +from typing import Callable, Optional, Literal import clang.cindex as cindex +import esbmc_ai_lib.config as config + from .ast import ClangAST from .ast_decl import Declaration, TypeDeclaration, FunctionDeclaration -from .c_types import is_primitive_type +from .c_types import is_primitive_type, get_base_type """Note about how `assign_to` and `init` work: @@ -78,7 +80,20 @@ def statement_primitive_construct( from https://github.com/esbmc/esbmc/blob/master/src/clang-c-frontend/clang_c_language.cpp """ - value: str = ESBMCCodeGenerator._primitives_base_defaults[d.type_name] + base_type_name: str = get_base_type(d.type_name) + base_value: str = ESBMCCodeGenerator._primitives_base_defaults[base_type_name] + + type_name: str = base_type_name + value: str = base_value + if d.is_pointer_type(): + # No difference if adding [] or *. Pointer means initialize contiguous memory.
+ type_name += "*" + value = ( + f"({type_name})" + + "{" + + ",".join([base_value] * config.ocm_array_expansion) + + "}" + ) if assign_to == None: return value @@ -86,29 +101,44 @@ def statement_primitive_construct( if assign_to == "": return value + ";" else: - return (d.type_name + " " if init else "") + f"{assign_to} = {value};" + return (type_name + " " if init else "") + f"{assign_to} = {value};" def statement_type_construct( self, - d: TypeDeclaration, + d_type: TypeDeclaration, + init_type: Literal["value", "ptr"], assign_to: Optional[str] = None, init: bool = False, primitive_assignment_fn: Optional[Callable[[Declaration], str]] = None, + max_depth: int = 5, + _current_depth: int = -1, ) -> str: - """Constructs a statement to that is represented by decleration d.""" - assert d.cursor + """Constructs a statement that is represented by declaration d_type. Needs ptr information + since TypeDeclaration does not carry such info (the Declaration base type does).""" + assert d_type.cursor + + _current_depth += 1 cmd: str # Check if primitive type and return nondet value. - if len(list(d.cursor.get_children())) == 0: - return self.statement_primitive_construct(d) + if len(list(d_type.cursor.get_children())) == 0: + return self.statement_primitive_construct(d_type) - cmd = "(" + (d.type_name if d.is_typedef() else f"struct {d.name}") + "){" + if d_type.is_typedef(): + cmd = f"({d_type.type_name})" + "{" + raise NotImplementedError("Typedefs not implemented...") + else: + if init_type == "value": + cmd = f"({d_type.construct_type} {d_type.name})" + "{" + elif init_type == "ptr": + cmd = f"({d_type.construct_type} {d_type.name}*)" + "{" + else: + raise ValueError(f"init_type has an invalid value: {init_type}") # Loop through each element of the data type. elements: list[str] = [] - for element in d.elements: + for element in d_type.elements: # Check if element is a primitive type. If not, it will need to be # further broken.
+ if is_primitive_type(element): @@ -133,22 +163,34 @@ def statement_type_construct( type_declaration != None ), f"Reference for type {element} could not be found" - # Get decleration in AST - element_code: str = self.statement_type_construct(type_declaration) - elements.append(element_code) + # Check if max depth is reached. If it has, then do not init pointer + # elements. + if _current_depth < max_depth or not element.is_pointer_type(): + # Get declaration in AST + element_code: str = self.statement_type_construct( + d_type=type_declaration, + init_type="ptr" if element.is_pointer_type() else "value", + max_depth=max_depth, + _current_depth=_current_depth, + primitive_assignment_fn=primitive_assignment_fn, + ) + elements.append(element_code) + else: + elements.append("NULL") + # Join the elements of the type initialization. cmd += ",".join(elements) + "}" # If this construction should be an assignment call. if assign_to != None: if assign_to == "": cmd = cmd + ";" - elif d.type_name != "void": + elif d_type.type_name != "void": # Check if assignment variable should be initialized.
if init: - assign_type: str = d.type_name - if d.type_name == "": - assign_type = f"{d.construct_type} {d.name}" + assign_type: str = d_type.type_name + if d_type.type_name == "": + assign_type = f"{d_type.construct_type} {d_type.name}" cmd = f"{assign_type} {assign_to} = {cmd};" else: @@ -188,6 +230,8 @@ def statement_function_call( underlying_type: cindex.Type = arg.cursor.type.get_canonical() underlying_cursor: cindex.Cursor = underlying_type.get_declaration() + arg_decl: Declaration = Declaration.from_cursor(underlying_cursor) + arg_type: Optional[ TypeDeclaration ] = self.ast._get_type_declaration_from_cursor(underlying_cursor) @@ -199,7 +243,12 @@ def statement_function_call( + " " + arg.type_name ) - arg_cmds.append(self.statement_type_construct(d=arg_type)) + arg_cmds.append( + self.statement_type_construct( + d_type=arg_type, + init_type="ptr" if arg_decl.is_pointer_type() else "value", + ) + ) cmd += ", ".join(arg_cmds) + ")"