Skip to content

Commit

Permalink
Move backend object files to be with their backends (#510)
Browse files Browse the repository at this point in the history
* Move backend object files to be with their backends

* Improve lint

* Improve lint more

* Fix tests
  • Loading branch information
twizmwazin authored Sep 23, 2024
1 parent 44c488b commit 0cf49e8
Show file tree
Hide file tree
Showing 32 changed files with 892 additions and 972 deletions.
15 changes: 1 addition & 14 deletions claripy/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from claripy import Backend
from claripy.annotation import Annotation
from claripy.backends import Backend

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -519,8 +519,6 @@ def _check_args_same(self, other_args: tuple[ArgType, ...], lenient_names=False)
"""
Check if two ASTs are the same.
"""
from claripy.vsa.strided_interval import StridedInterval # pylint:disable=import-outside-toplevel

# Several types inside of args don't support normall == comparison, so if we see those,
# we need compare them manually.
for a, b in zip(self.args, other_args, strict=True):
Expand All @@ -535,17 +533,6 @@ def _check_args_same(self, other_args: tuple[ArgType, ...], lenient_names=False)
continue
if a != b:
return False
if (
isinstance(a, StridedInterval)
and isinstance(b, StridedInterval)
and (
a.bits != b.bits
or a.lower_bound != b.lower_bound
or a.upper_bound != b.upper_bound
or a.stride != b.stride
)
):
return False
if lenient_names and isinstance(a, str) and isinstance(b, str):
continue
if a != b:
Expand Down
10 changes: 5 additions & 5 deletions claripy/ast/bv.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ def identical(self, other: Self, strict=False) -> bool:
return super().identical(other, strict)


def BVS(
def BVS( # pylint:disable=redefined-builtin
name,
size,
min=None,
max=None,
stride=None,
uninitialized=False, # pylint:disable=redefined-builtin
uninitialized=False,
explicit_name=None,
discrete_set=False,
discrete_set_max_card=None,
Expand Down Expand Up @@ -312,7 +312,7 @@ def SI(
):
name = "unnamed" if name is None else name
if to_conv is not None:
si = claripy.vsa.CreateStridedInterval(
si = claripy.backends.backend_vsa.CreateStridedInterval(
name=name, bits=bits, lower_bound=lower_bound, upper_bound=upper_bound, stride=stride, to_conv=to_conv
)
return BVS(
Expand Down Expand Up @@ -352,7 +352,7 @@ def ValueSet(bits, region=None, region_base_addr=None, value=None, name=None, va
if isinstance(v, numbers.Number):
min_v, max_v = v, v
stride = 0
elif isinstance(v, claripy.vsa.StridedInterval):
elif isinstance(v, claripy.backends.backend_vsa.StridedInterval):
min_v, max_v = v.lower_bound, v.upper_bound
stride = v.stride
elif isinstance(v, claripy.ast.Base):
Expand Down Expand Up @@ -381,7 +381,7 @@ def DSIS(
name=None, bits=0, lower_bound=None, upper_bound=None, stride=None, explicit_name=None, to_conv=None, max_card=None
):
if to_conv is not None:
si = claripy.vsa.CreateStridedInterval(bits=to_conv.size(), to_conv=to_conv)
si = claripy.backends.backend_vsa.CreateStridedInterval(bits=to_conv.size(), to_conv=to_conv)
return SI(
name=name,
bits=si._bits,
Expand Down
2 changes: 1 addition & 1 deletion claripy/ast/bv.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, overload

from claripy.bv import BVV as ConcreteBVV
from claripy.backends.backend_concrete.bv import BVV as ConcreteBVV
from claripy.fp import RM, FSort

from .bits import Bits
Expand Down
38 changes: 15 additions & 23 deletions claripy/ast/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import struct

from claripy import fp, operations
from claripy import operations
from claripy.ast.base import _make_name
from claripy.fp import FSORT_FLOAT
from claripy.fp import FSORT_FLOAT, RM, FSort

from .bits import Bits
from .bool import Bool
Expand Down Expand Up @@ -33,7 +33,7 @@ def to_fp(self, sort, rm=None):
:return: An FP AST
"""
if rm is None:
rm = fp.RM.default()
rm = RM.default()

return fpToFP(rm, self, sort)

Expand Down Expand Up @@ -65,14 +65,14 @@ def val_to_bv(self, size, signed=True, rm=None):
:return: A bitvector whose value is the rounded version of this FP's value
"""
if rm is None:
rm = fp.RM.default()
rm = RM.default()

op = fpToSBV if signed else fpToUBV
return op(rm, self, size)

@property
def sort(self):
return fp.FSort.from_size(self.length)
return FSort.from_size(self.length)

@staticmethod
def _from_float(like, value):
Expand Down Expand Up @@ -109,7 +109,7 @@ def FPV(value, sort):
elif not isinstance(value, float):
raise TypeError("Must instanciate FPV with a numerical value")

if not isinstance(sort, fp.FSort):
if not isinstance(sort, FSort):
raise TypeError("Must instanciate FPV with a FSort")

if sort == FSORT_FLOAT:
Expand All @@ -126,19 +126,19 @@ def FPV(value, sort):


def _fp_length_calc(a1, a2, a3=None):
if isinstance(a1, fp.RM) and a3 is None:
if isinstance(a1, RM) and a3 is None:
raise Exception
if a3 is None:
return a2.length
return a3.length


fpToFP = operations.op("fpToFP", object, FP, calc_length=_fp_length_calc)
fpToFPUnsigned = operations.op("fpToFPUnsigned", (fp.RM, BV, fp.FSort), FP, calc_length=_fp_length_calc)
fpToFPUnsigned = operations.op("fpToFPUnsigned", (RM, BV, FSort), FP, calc_length=_fp_length_calc)
fpFP = operations.op("fpFP", (BV, BV, BV), FP, calc_length=lambda a, b, c: a.length + b.length + c.length)
fpToIEEEBV = operations.op("fpToIEEEBV", (FP,), BV, calc_length=lambda fp: fp.length)
fpToSBV = operations.op("fpToSBV", (fp.RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)
fpToUBV = operations.op("fpToUBV", (fp.RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)
fpToSBV = operations.op("fpToSBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)
fpToUBV = operations.op("fpToUBV", (RM, FP, int), BV, calc_length=lambda _rm, _fp, len: len)

#
# unbound float point comparisons
Expand Down Expand Up @@ -172,19 +172,11 @@ def _fp_binop_length(rm, a, b): # pylint:disable=unused-argument

fpAbs = operations.op("fpAbs", (FP,), FP, calc_length=lambda x: x.length)
fpNeg = operations.op("fpNeg", (FP,), FP, calc_length=lambda x: x.length)
fpSub = operations.op("fpSub", (fp.RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpAdd = operations.op("fpAdd", (fp.RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpMul = operations.op("fpMul", (fp.RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpDiv = operations.op("fpDiv", (fp.RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpSqrt = operations.op(
"fpSqrt",
(
fp.RM,
FP,
),
FP,
calc_length=lambda _, x: x.length,
)
fpSub = operations.op("fpSub", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpAdd = operations.op("fpAdd", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpMul = operations.op("fpMul", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpDiv = operations.op("fpDiv", (RM, FP, FP), FP, extra_check=_fp_binop_check, calc_length=_fp_binop_length)
fpSqrt = operations.op("fpSqrt", (RM, FP), FP, calc_length=lambda _, x: x.length)

#
# bound fp operations
Expand Down
2 changes: 2 additions & 0 deletions claripy/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .backend import Backend
from .backend_concrete import BackendConcrete
from .backend_vsa import BackendVSA
from .backend_z3 import BackendZ3
Expand All @@ -12,6 +13,7 @@
backends_by_type = {b.__class__.__name__: b for b in all_backends}

__all__ = (
"Backend",
"BackendZ3",
"BackendConcrete",
"BackendVSA",
Expand Down
2 changes: 1 addition & 1 deletion claripy/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def cardinality(self, a):
"""
return self._cardinality(self.convert(a))

def _cardinality(self, b): # pylint:disable=no-self-use,unused-argument
def _cardinality(self, a): # pylint:disable=no-self-use,unused-argument
"""
This should return the maximum number of values that an expression can take on. This should be a strict
*over* approximation.
Expand Down
8 changes: 8 additions & 0 deletions claripy/backends/backend_concrete/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

from .backend_concrete import BackendConcrete
from .bv import BVV
from .fp import FPV
from .strings import StringV

__all__ = ("BackendConcrete", "BVV", "FPV", "StringV")
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@
import operator
from functools import reduce

from claripy import bv, fp, strings
from claripy.ast import Base
from claripy.ast.bool import Bool, BoolV
from claripy.ast.bv import BV, BVV
from claripy.ast.fp import FPV
from claripy.ast.strings import StringV
from claripy.backends.backend import Backend
from claripy.backends.backend_concrete import bv, fp, strings
from claripy.errors import BackendError, UnsatError
from claripy.operations import backend_fp_operations, backend_operations, backend_strings_operations

log = logging.getLogger(__name__)


# pylint: disable=too-many-positional-arguments


class BackendConcrete(Backend):
__slots__ = ()

Expand Down Expand Up @@ -153,7 +156,7 @@ def _abstract(self, e): # pylint:disable=no-self-use
return StringV(e.value)
raise BackendError(f"Couldn't abstract object of type {type(e)}")

def _cardinality(self, b):
def _cardinality(self, a): # pylint:disable=unused-argument
# if we got here, it's a cardinality of 1
return 1

Expand Down
10 changes: 7 additions & 3 deletions claripy/bv.py → claripy/backends/backend_concrete/bv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import functools
import numbers

from . import debug as _d
from .backend_object import BackendObject
from .errors import ClaripyOperationError, ClaripyTypeError, ClaripyZeroDivisionError
from claripy import debug as _d
from claripy.backends.backend_object import BackendObject
from claripy.errors import ClaripyOperationError, ClaripyTypeError, ClaripyZeroDivisionError


def compare_bits(f):
Expand Down Expand Up @@ -49,6 +49,10 @@ def normalize_helper(self, o):


class BVV(BackendObject):
"""A concrete bitvector value. Used in the concrete backend for calculations.
Any use outside of claripy should use `claripy.ast.bv.BVV` instead.
"""

__slots__ = ["bits", "_value", "mod"]

def __init__(self, value, bits):
Expand Down
Loading

0 comments on commit 0cf49e8

Please sign in to comment.