Skip to content

Commit

Permalink
support translation to bounded expression
Browse files Browse the repository at this point in the history
  • Loading branch information
danbryce committed Oct 10, 2024
1 parent 8e7e6b5 commit cdb03a9
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 86 deletions.
261 changes: 181 additions & 80 deletions notebooks/monthly-demos/helpers.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/funman/model/petrinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ def gradient(self, t, y, *p):
]
# print(f"vars: {self._state_var_names()}")
# print(f"gradient: {grad}")
assert not any([not isinstance(v, float) for v in grad]), f"Gradient has a non-float element: {grad}"
assert not any(
[not isinstance(v, float) for v in grad]
), f"Gradient has a non-float element: {grad}"
return grad


Expand Down
2 changes: 1 addition & 1 deletion src/funman/search/box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def _expand(
"dreal_log_level": episode.config.dreal_log_level,
"dreal_mcts": episode.config.dreal_mcts,
"preferred": episode.config.dreal_prefer_parameters, # [p.name for p in episode.problem.parameters] if episode.config.dreal_prefer_parameters else [],
"random_seed": episode.config.random_seed
"random_seed": episode.config.random_seed,
}
else:
opts = {}
Expand Down
2 changes: 1 addition & 1 deletion src/funman/search/simulator_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def expand(
"dreal_log_level": episode.config.dreal_log_level,
"dreal_mcts": episode.config.dreal_mcts,
"preferred": episode.config.dreal_prefer_parameters, # [p.name for p in problem.model_parameters()]if episode.config.dreal_prefer_parameters else [],
"random_seed": episode.config.random_seed
"random_seed": episode.config.random_seed,
}
else:
opts = {}
Expand Down
2 changes: 1 addition & 1 deletion src/funman/search/smt_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def expand(
"dreal_log_level": episode.config.dreal_log_level,
"dreal_mcts": episode.config.dreal_mcts,
"preferred": episode.config.dreal_prefer_parameters, # [p.name for p in problem.model_parameters()]if episode.config.dreal_prefer_parameters else [],
"random_seed": episode.config.random_seed
"random_seed": episode.config.random_seed,
}
else:
opts = {}
Expand Down
7 changes: 5 additions & 2 deletions src/funman/translate/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,11 @@ def _encode_state_variable_constraint(
):
time = options.schedule.time_at_step(layer_idx)

if (( constraint.variable in scenario.model._observable_names() and
scenario.model.is_timed_observable(constraint.variable))
if (
(
constraint.variable in scenario.model._observable_names()
and scenario.model.is_timed_observable(constraint.variable)
)
or constraint.variable in scenario.model._state_var_names()
) and constraint.contains_time(time):
bounds = (
Expand Down
100 changes: 100 additions & 0 deletions src/funman/utils/sympy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pysmt.operators as op
import pysmt.typing as types
import sympy
from pydantic import BaseModel, ConfigDict
from pysmt.formula import FNode, FormulaManager
from pysmt.shortcuts import FALSE, GE, GT, LE, LT, REAL, TRUE
from pysmt.shortcuts import Abs as pysmt_Abs
Expand All @@ -28,6 +29,105 @@
l = logging.getLogger(__name__)


class SympyBoundedSubstituter(BaseModel):
model_config = ConfigDict(
arbitrary_types_allowed=True,
)

bound_symbols: Dict[str, Dict[str, str]] = {}
str_to_symbol: Dict[str, sympy.Symbol] = {}

def _prepare_expression(
self, derivative_variable: str, expr: str, bound: str
) -> sympy.Expr:
deriv_var_symbol = self.str_to_symbol[derivative_variable]
sym_expr = to_sympy(expr, self.str_to_symbol)
# substitute lb for deriv_var_symbol in sym_expr
lb_symbol = self.bound_symbols[derivative_variable][bound]
sym_expr = sym_expr.subs({derivative_variable: lb_symbol})
return sym_expr

def maximize(self, derivative_variable: str, expr: str) -> str:
sym_expr = self._prepare_expression(derivative_variable, expr, "ub")
m_expr = self._substitute(sym_expr, False)
return m_expr

def minimize(self, derivative_variable: str, expr: str) -> str:
sym_expr = self._prepare_expression(derivative_variable, expr, "lb")
m_expr = self._substitute(sym_expr, True)
return m_expr

def _substitute(self, expr: sympy.Expr, sub_min: bool):
func = expr.func
# if func.is_Boolean:
# return TRUE() if isinstance(expr, BooleanTrue) else FALSE()
if func.is_Mul:
return self._substitute_op(func, expr, sub_min)
elif func.is_Add:
return self._substitute_op(func, expr, sub_min)
# elif isinstance(expr, Abs):
# return self._substitute_abs(expr)
elif func.is_Symbol:
return self._substitute_symbol(expr, sub_min, op_type=REAL)
elif func.is_Pow:
return self._substitute_pow(expr, sub_min)
# elif isinstance(expr, exp):
# return Pow(self._substitute_real(math.e), self._substitute(expr.exp))
# elif expr.is_Boolean:
# return self._substitute_op(And, expr)
# elif func.is_Relational:
# if func.rel_op == "<=":
# return self._substitute_op(LE, expr, explode=True)
# elif func.rel_op == "<":
# return self._substitute_op(LT, expr, explode=True)
# elif func.rel_op == ">=":
# return self._substitute_op(GE, expr, explode=True)
# elif func.rel_op == ">":
# return self._substitute_op(GT, expr, explode=True)
# elif func.rel_op == "==":
# return self._substitute_op(Equals, expr, explode=True)
# elif expr.is_Piecewise:
# return self._substitute_piecewise(expr)
elif expr.is_constant():
return self._substitute_real(expr, sub_min)

else:
raise Exception(f"Could not convert expression: {expr}")

def _substitute_op(self, op, expr, sub_min: bool, explode=True):

next_min = (
not sub_min if sympy.core.function._coeff_isneg(expr) else sub_min
)
terms = [self._substitute(arg, next_min) for arg in expr.args]
return op(*terms) if explode else op(terms)

def _substitute_pow(self, expr, sub_min: bool):
base = expr.args[0]
exponent = expr.args[1]
next_min = (
not sub_min
if sympy.core.function._coeff_isneg(exponent)
else sub_min
)
return sympy.Pow(
self._substitute(base, next_min),
self._substitute(exponent, sub_min),
)

def _substitute_symbol(self, expr, sub_min: bool, op_type=REAL):
sym = str(expr)
bound = "lb" if sub_min else "ub"
return (
sympy.Symbol(self.bound_symbols[sym][bound])
if not sym.endswith("_lb") and not sym.endswith("_ub")
else expr
)

def _substitute_real(self, expr, sub_min: bool):
return expr


class SympySerializer(IdentityDagWalker):
def __init__(self):
super().__init__(invalidate_memoization=True)
Expand Down
87 changes: 87 additions & 0 deletions test/test_abstraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import logging
import sys
import unittest

import sympy

from funman.utils.sympy_utils import SympyBoundedSubstituter, to_sympy


class TestUseCases(unittest.TestCase):
l = logging.Logger(__name__)

def setUp(self):
logging.basicConfig(level=logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG)
self.l.level = logging.getLogger().level
self.l.handlers.append(logging.StreamHandler(sys.stdout))

def test_minimize_expression(self):
tests = [
{
"input": ["S", "- S * I * beta/N"],
"bound": "lb",
"expected_output": "-I_ub*S_lb*beta_ub/N_lb",
},
{
"input": ["S", "- S * I * beta/N"],
"bound": "ub",
"expected_output": "-I_lb*S_ub*beta_lb/N_ub",
},
{
"input": ["S", "- S * I * beta"],
"bound": "lb",
"expected_output": "-I_ub*S_lb*beta_ub",
},
{
"input": ["S", "- S * I * beta"],
"bound": "ub",
"expected_output": "-I_lb*S_ub*beta_lb",
},
{
"input": ["I", "S * I * beta - I * gamma"],
"bound": "lb",
"expected_output": "I_lb*S_lb*beta_lb - I_lb*gamma_ub",
},
{
"input": ["I", "S * I * beta - I * gamma"],
"bound": "ub",
"expected_output": "I_ub*S_ub*beta_ub - I_ub*gamma_lb",
},
{
"input": ["R", "I * gamma"],
"bound": "lb",
"expected_output": "I_lb*gamma_lb",
},
{
"input": ["I", "I * gamma"],
"bound": "ub",
"expected_output": "I_ub*gamma_ub",
},
]

str_symbols = ["S", "I", "R", "beta", "gamma", "N"]
symbols = {s: sympy.Symbol(s) for s in str_symbols}
bound_symbols = {
s: {"lb": f"{s}_lb", "ub": f"{s}_ub"} for s in str_symbols
}
substituter = SympyBoundedSubstituter(
bound_symbols=bound_symbols, str_to_symbol=symbols
)

for test in tests:
with self.subTest(f"{test['bound']}({test['input']})"):
test_fn = (
substituter.minimize
if test["bound"] == "lb"
else substituter.maximize
)
test_output = test_fn(*test["input"])
# self.l.debug(f"Minimized: [{infection_rate}], to get expression: [{test_output}]")
assert (
str(test_output) == test["expected_output"]
), f"Failed to create the expected expression: [{test['expected_output']}], got [{test_output}]"


if __name__ == "__main__":
unittest.main()

0 comments on commit cdb03a9

Please sign in to comment.