Skip to content

Commit

Permalink
OCM Command Update: Refactored and improved code.
Browse files Browse the repository at this point in the history
* Added pointer and contiguous memory support.
* Added init depth limit. Pointers get initialized to NULL after limit.
* Separated logic.
* _generate_primitive_type_variables generates the nondet variables.
* The OCM function then passes again through the types and
builds the code.

ESBMC Code Gen Update:

* Added support for pointers and contiguous memory.
  • Loading branch information
Yiannis128 committed Oct 3, 2023
1 parent 60c795e commit 8eabea5
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 65 deletions.
147 changes: 101 additions & 46 deletions esbmc_ai_lib/commands/optimize_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import os
import sys
from os import get_terminal_size
from typing import Iterable, Optional, Tuple
from typing import Optional, Tuple
from typing_extensions import override
from string import Template
from random import randint

from esbmc_ai_lib.chat_response import json_to_base_messages
from esbmc_ai_lib.frontend.ast_decl import Declaration, TypeDeclaration
from esbmc_ai_lib.frontend.c_types import is_primitive_type
from esbmc_ai_lib.frontend.c_types import is_primitive_type, get_base_type
from esbmc_ai_lib.frontend.esbmc_code_generator import ESBMCCodeGenerator
from esbmc_ai_lib.esbmc_util import esbmc_load_source_code
from esbmc_ai_lib.msg_bus import Signal
Expand All @@ -20,7 +20,7 @@
from ..base_chat_interface import ChatResponse
from ..optimize_code import OptimizeCode
from ..frontend import ast
from ..frontend.ast import FunctionDeclaration
from ..frontend.ast import ClangAST, FunctionDeclaration
from ..logging import printvv


Expand All @@ -34,6 +34,10 @@ def __init__(self) -> None:
)
self.on_solution_signal: Signal = Signal()

@staticmethod
def _generate_param_name(d: Declaration) -> str:
    """Build a variable name for the declaration `d`.

    The name has the form ``param_<type>_<number>``, where ``<type>`` is
    `d.type_name` with spaces replaced by underscores and ``*`` replaced by
    ``ptr``, and ``<number>`` is a random integer in [0, 99999] used to keep
    names for the same type apart.
    """
    sanitized_type: str = d.type_name.replace(" ", "_").replace("*", "ptr")
    suffix: int = randint(a=0, b=99999)
    return f"param_{sanitized_type}_{suffix}"

def _get_functions_list(
self,
clang_ast: ast.ClangAST,
Expand All @@ -58,6 +62,62 @@ def _get_functions_list(
function_names = all_function_names.copy()
return function_names

def _generate_primitive_type_variables(
    self,
    elements: list[Declaration],
    code_gen: ESBMCCodeGenerator,
    max_depth: int = 10,
    _current_depth: int = -1,
) -> tuple[list[str], list[str]]:
    """Generate the initialization variables for a list of Declarations.

    Traverses `elements` depth first. As soon as a primitive type is
    reached, a variable is declared for it through
    `code_gen.statement_primitive_construct` (using __VERIFIER_nondet_X()
    as the value). Non-primitive elements are resolved to their type
    declaration and their own elements are traversed recursively.

    Once `max_depth` is reached, pointer elements are no longer expanded and
    "NULL" is recorded for them instead. By-value elements are still
    expanded, mirroring the depth check in
    `ESBMCCodeGenerator.statement_type_construct`: NULL is not a valid
    initializer for an aggregate value, and by-value nesting is always
    finite in C.

    Args:
        elements: Declarations to generate variables for (for example the
            arguments of a function, or the elements of a type).
        code_gen: Code generator whose AST is used to resolve types and
            which builds the declaration statements.
        max_depth: Depth after which pointer elements become NULL.
        _current_depth: Internal recursion counter; callers leave the
            default.

    Returns:
        A tuple `(names, statements)` of equal length: the variable names
        (or "NULL") and the matching declaration statements (an empty
        string for "NULL" entries).
    """
    # Named clang_ast so that the module-level `ast` import is not shadowed.
    clang_ast: ClangAST = code_gen.ast

    names: list[str] = []
    statements: list[str] = []

    _current_depth += 1

    for arg in elements:
        # Report the offending element rather than the whole list.
        assert arg.cursor, f"Element {arg} has no valid cursor."

        if is_primitive_type(arg):
            name: str = OptimizeCodeCommand._generate_param_name(arg)
            statement = code_gen.statement_primitive_construct(
                d=arg,
                assign_to=name,
                init=True,
            )
            names.append(name)
            statements.append(statement)
            continue

        arg_type: Optional[
            TypeDeclaration
        ] = clang_ast._get_type_declaration_from_cursor(arg.cursor)
        assert arg_type, f"Assert failed: {arg}"

        # Only pointers are cut off at the depth limit; by-value aggregates
        # must still be expanded since they cannot be initialized to NULL.
        if _current_depth < max_depth or not arg.is_pointer_type():
            # Traverse tree
            n, s = self._generate_primitive_type_variables(
                elements=arg_type.elements,
                code_gen=code_gen,
                max_depth=max_depth,
                _current_depth=_current_depth,
            )
            names.extend(n)
            statements.extend(s)
        else:
            # Depth limit reached on a pointer: generate NULL, no statement.
            names.append("NULL")
            statements.append("")

    return names, statements

def _build_comparison_script(
self,
old_ast: ast.ClangAST,
Expand Down Expand Up @@ -110,52 +170,39 @@ def rename_append_declaration(ast: ast.ClangAST, append: str):
code_gen_old: ESBMCCodeGenerator = ESBMCCodeGenerator(old_ast)
code_gen_new: ESBMCCodeGenerator = ESBMCCodeGenerator(new_ast)

# This Callable is used to record the primitive types that the parameters of each
# function take. They will be injected into the main method and supplied to
# both old and new function calls. Since they have the same function, they will
# generate the same code.
def primitive_assignemnt_old(decl: Declaration) -> str:
name: str = (
f"param_{decl.type_name.replace(' ', '')}_{randint(a=0, b=99999)}"
)
statement: str = code_gen_old.statement_primitive_construct(
d=decl,
assign_to=name,
init=True,
)
decl_statements.append(statement)
decl_statement_names.append(name)

return name

def primitive_assignemnt_new(_: Declaration) -> str:
return decl_statement_names.pop(0)

# Generate parameters using same function call of __VERIFIER_nondet_X for primitives.
# The parameters will be inlined into the function.

# Variable names of primitive values.
primitive_vars_names: list[str] = []
# Declaration statements for the variables
decl_statements: list[str] = []
# Name of the variables
decl_statement_names: list[str] = []
# Params for each function
statements: list[str] = []
primitive_vars_names, statements = self._generate_primitive_type_variables(
elements=old_function.args,
code_gen=code_gen_old,
max_depth=config.ocm_init_max_depth,
)

# Function parameters.
fn_params_old: list[str] = []
fn_params_new: list[str] = []

# Generate function parameter values.
param_idx: int = 0
for arg_old, arg_new in zip(old_function.args, new_function.args):
# Convert to type
assert arg_old.cursor and arg_new.cursor

if is_primitive_type(arg_old.type_name):
name: str = f"param_{arg_old.type_name.replace(' ', '')}_{randint(a=0, b=99999)}"
statement = code_gen_old.statement_primitive_construct(
d=arg_old,
assign_to=name,
init=True,
)
fn_params_old.append(name)
fn_params_new.append(name)

decl_statements.append(statement)
# If it's a primitive type, then no need to call the code generator,
# can simply add name of primitive value.
if is_primitive_type(arg_old):
fn_params_old.append(primitive_vars_names[param_idx])
fn_params_new.append(primitive_vars_names[param_idx])
param_idx += 1
else:
if arg_old.is_pointer_type():
pass

arg_old_type: Optional[
TypeDeclaration
] = old_ast._get_type_declaration_from_cursor(arg_old.cursor)
Expand All @@ -165,20 +212,28 @@ def primitive_assignemnt_new(_: Declaration) -> str:
assert (
arg_old_type and arg_new_type
), f"Assert failed: {arg_old} or {arg_new}"

# Create statements and save them.
statement_old = code_gen_old.statement_type_construct(
d=arg_old_type,
d_type=arg_old_type,
init_type="ptr" if arg_old.is_pointer_type() else "value",
init=False,
primitive_assignment_fn=primitive_assignemnt_old,
primitive_assignment_fn=lambda _: primitive_vars_names[param_idx],
max_depth=config.ocm_init_max_depth,
)
fn_params_old.append(statement_old)

statement_new = code_gen_new.statement_type_construct(
d=arg_new_type,
d_type=arg_new_type,
init_type="ptr" if arg_new.is_pointer_type() else "value",
init=False,
primitive_assignment_fn=primitive_assignemnt_new,
primitive_assignment_fn=lambda _: primitive_vars_names[param_idx],
max_depth=config.ocm_init_max_depth,
)
fn_params_new.append(statement_new)

param_idx += 1

# Generate the function call & arguments
old_params_src: str = code_gen_old.statement_function_call(
fn=old_function,
Expand All @@ -193,10 +248,12 @@ def primitive_assignemnt_new(_: Declaration) -> str:
init=True,
)

# TODO Dereference fn result if pointer.

script: str = template.substitute(
function_old=old_ast.source_code,
function_new=new_ast.source_code,
parameters_list="\n".join(decl_statements),
parameters_list="\n".join(statements),
function_call_old=old_params_src,
function_call_new=new_params_src,
function_assert_old="old_fn_result",
Expand Down Expand Up @@ -226,8 +283,6 @@ def check_function_pequivalence(
)

# Get list of function types.
# TODO: Need to double check that set needs to be used instead of list here
# as the list may return duplicates.
old_functions: list[FunctionDeclaration] = original_ast.get_fn_decl()
new_functions: list[FunctionDeclaration] = new_ast.get_fn_decl()

Expand Down
87 changes: 68 additions & 19 deletions esbmc_ai_lib/frontend/esbmc_code_generator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Author: Yiannis Charalambous

from typing import Callable, Optional
from typing import Callable, Optional, Literal

import clang.cindex as cindex

import esbmc_ai_lib.config as config

from .ast import ClangAST
from .ast_decl import Declaration, TypeDeclaration, FunctionDeclaration
from .c_types import is_primitive_type
from .c_types import is_primitive_type, get_base_type


"""Note about how `assign_to` and `init` work:
Expand Down Expand Up @@ -78,37 +80,65 @@ def statement_primitive_construct(
from https://github.com/esbmc/esbmc/blob/master/src/clang-c-frontend/clang_c_language.cpp
"""

value: str = ESBMCCodeGenerator._primitives_base_defaults[d.type_name]
base_type_name: str = get_base_type(d.type_name)
base_value: str = ESBMCCodeGenerator._primitives_base_defaults[base_type_name]

type_name: str = base_type_name
value: str = base_value
if d.is_pointer_type():
# No difference if adding [] or *. Pointer means initialize continuous memory.
type_name += "*"
value = (
f"({type_name})"
+ "{"
+ ",".join([base_value] * config.ocm_array_expansion)
+ "}"
)

if assign_to == None:
return value
else:
if assign_to == "":
return value + ";"
else:
return (d.type_name + " " if init else "") + f"{assign_to} = {value};"
return (type_name + " " if init else "") + f"{assign_to} = {value};"

def statement_type_construct(
self,
d: TypeDeclaration,
d_type: TypeDeclaration,
init_type: Literal["value", "ptr"],
assign_to: Optional[str] = None,
init: bool = False,
primitive_assignment_fn: Optional[Callable[[Declaration], str]] = None,
max_depth: int = 5,
_current_depth: int = -1,
) -> str:
"""Constructs a statement to that is represented by decleration d."""
assert d.cursor
"""Constructs a statement that is represented by decleration d. Need ptr information
since TypeDeclaration does not carry such info (Declaration) base type does."""
assert d_type.cursor

_current_depth += 1

cmd: str

# Check if primitive type and return nondet value.
if len(list(d.cursor.get_children())) == 0:
return self.statement_primitive_construct(d)
if len(list(d_type.cursor.get_children())) == 0:
return self.statement_primitive_construct(d_type)

cmd = "(" + (d.type_name if d.is_typedef() else f"struct {d.name}") + "){"
if d_type.is_typedef():
cmd = f"({d_type.type_name})" + "{"
raise NotImplementedError("Typedefs not implemented...")
else:
if init_type == "value":
cmd = f"({d_type.construct_type} {d_type.name})" + "{"
elif init_type == "ptr":
cmd = f"({d_type.construct_type} {d_type.name}*)" + "{"
else:
raise ValueError(f"init_type has an invalid value: {init_type}")

# Loop through each element of the data type.
elements: list[str] = []
for element in d.elements:
for element in d_type.elements:
# Check if element is a primitive type. If not, it will need to be
# further broken.
if is_primitive_type(element):
Expand All @@ -133,22 +163,34 @@ def statement_type_construct(
type_declaration != None
), f"Reference for type {element} could not be found"

# Get decleration in AST
element_code: str = self.statement_type_construct(type_declaration)
elements.append(element_code)
# Check if max depth is reached. If it has, then do not init pointer
# elements.
if _current_depth < max_depth or not element.is_pointer_type():
# Get decleration in AST
element_code: str = self.statement_type_construct(
d_type=type_declaration,
init_type="ptr" if element.is_pointer_type() else "value",
max_depth=max_depth,
_current_depth=_current_depth,
primitive_assignment_fn=primitive_assignment_fn,
)
elements.append(element_code)
else:
elements.append("NULL")

# Join the elements of the type initialization.
cmd += ",".join(elements) + "}"

# If this construction should be an assignment call.
if assign_to != None:
if assign_to == "":
cmd = cmd + ";"
elif d.type_name != "void":
elif d_type.type_name != "void":
# Check if assignment variable should be initialized.
if init:
assign_type: str = d.type_name
if d.type_name == "":
assign_type = f"{d.construct_type} {d.name}"
assign_type: str = d_type.type_name
if d_type.type_name == "":
assign_type = f"{d_type.construct_type} {d_type.name}"

cmd = f"{assign_type} {assign_to} = {cmd};"
else:
Expand Down Expand Up @@ -188,6 +230,8 @@ def statement_function_call(
underlying_type: cindex.Type = arg.cursor.type.get_canonical()
underlying_cursor: cindex.Cursor = underlying_type.get_declaration()

arg_decl: Declaration = Declaration.from_cursor(underlying_cursor)

arg_type: Optional[
TypeDeclaration
] = self.ast._get_type_declaration_from_cursor(underlying_cursor)
Expand All @@ -199,7 +243,12 @@ def statement_function_call(
+ " "
+ arg.type_name
)
arg_cmds.append(self.statement_type_construct(d=arg_type))
arg_cmds.append(
self.statement_type_construct(
d_type=arg_type,
init_type="ptr" if arg_decl.is_pointer_type() else "value",
)
)

cmd += ", ".join(arg_cmds) + ")"

Expand Down

0 comments on commit 8eabea5

Please sign in to comment.