-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator.py
194 lines (171 loc) · 6.99 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Set, Type

from synth.syntax.program import Function, Primitive, Program, Variable
from synth.syntax.type_system import PrimitiveType
class Evaluator(ABC):
    """
    Abstract interface for objects that run a ``Program`` on an input.

    Concrete subclasses must provide both ``eval`` and ``clear_cache``;
    the class cannot be instantiated until they do.
    """

    @abstractmethod
    def eval(self, program: Program, input: Any) -> Any:
        """Run ``program`` on ``input`` and return the resulting value."""

    @abstractmethod
    def clear_cache(self) -> None:
        """Drop whatever cached results this evaluator may be holding."""
def __tuplify__(element: Any) -> Any:
    """
    Recursively convert lists (at any nesting depth) into tuples.

    Any value that is not a list is returned unchanged; this includes tuples,
    whose contents are not visited. The result is meant to serve as a
    dictionary key.
    """
    if not isinstance(element, list):
        return element
    return tuple(map(__tuplify__, element))
def auto_complete_semantics(
    primitives: Iterable[str], semantics: Dict[str, Any]
) -> None:
    """
    Give specialised primitives the semantics of their base name.

    A primitive whose name contains ``@`` and that has no entry in
    ``semantics`` receives the same semantic object as its base name, i.e. the
    text before the first ``@``, provided that base name is defined.
    ``semantics`` is updated in place, and existing entries are never
    overwritten.

    Examples:
    1) and, and@0, and@1
    Defining only ``and`` and then auto-completing gives the same semantic to
    all three primitives.
    2) or@0
    Since ``or`` has no semantic, ``or@0`` does not receive one either.
    """
    # Lazily filtered, so each membership test sees the updates made for
    # earlier names.
    undefined_variants = (
        name for name in primitives if "@" in name and name not in semantics
    )
    for variant in undefined_variants:
        base = variant.partition("@")[0]
        if base in semantics:
            semantics[variant] = semantics[base]
class DSLEvaluator(Evaluator):
    """
    Evaluate DSL programs by looking up primitive semantics in a dictionary.

    A ``Function`` node is evaluated by applying its function's value to its
    arguments one at a time (``f(a)(b)...``), so multi-argument semantics are
    expected to be curried.

    When ``use_cache`` is True, the value of every evaluated sub-program is
    memoised per input. The input is made hashable with ``__tuplify__``, and
    later programs that share sub-programs on the same input reuse those values.

    Attributes:
        semantics: primitive name -> semantic value.
        use_cache: whether results are kept between calls to ``eval``.
        skip_exceptions: exception classes that make ``eval`` return None
            instead of propagating. Matching uses the exact type
            (``type(e) in skip_exceptions``), so subclasses are NOT matched.
    """

    def __init__(self, semantics: Dict[str, Any], use_cache: bool = True) -> None:
        super().__init__()
        self.semantics = semantics
        self.use_cache = use_cache
        # tuplified input -> {sub-program: value}
        self._cache: Dict[Any, Dict[Program, Any]] = {}
        # Never read by this class; only reset by clear_cache.
        self._cons_cache: Dict[Any, Dict[Program, Any]] = {}
        # Holds exception *classes* (compared against type(e)), not instances.
        self.skip_exceptions: Set[Type[Exception]] = set()
        # Statistics
        self._total_requests = 0
        self._cache_hits = 0

    def eval(self, program: Program, input: List) -> Any:
        """
        Evaluate ``program`` on ``input``.

        Args:
            program: the program to run.
            input: the argument values; a ``Variable`` node reads
                ``input[variable]``.

        Returns:
            The program's value, or None if evaluation raised an exception
            whose exact type is in ``skip_exceptions``. In that case None is
            also stored as the program's cached value.

        Raises:
            Any other exception raised during evaluation, re-raised unchanged.
        """
        if self.use_cache:
            # Build and hash the key only when caching. With use_cache=False,
            # inputs holding unhashable values (dicts, sets, ...) must still be
            # evaluable.
            evaluations: Dict[Program, Any] = self._cache.setdefault(
                __tuplify__(input), {}
            )
        else:
            evaluations = {}
        if program in evaluations:
            return evaluations[program]
        try:
            # NOTE(review): relies on depth_first_iter yielding every child
            # before its parent, since a Function node reads its children's
            # values from `evaluations`. Confirm in synth.syntax.program.
            for sub_prog in program.depth_first_iter():
                self._total_requests += 1
                if sub_prog in evaluations:
                    self._cache_hits += 1
                    continue
                if isinstance(sub_prog, Primitive):
                    evaluations[sub_prog] = self.semantics[sub_prog.primitive]
                elif isinstance(sub_prog, Variable):
                    evaluations[sub_prog] = input[sub_prog.variable]
                elif isinstance(sub_prog, Function):
                    fun = evaluations[sub_prog.function]
                    for arg in sub_prog.arguments:
                        fun = fun(evaluations[arg])
                    evaluations[sub_prog] = fun
        except Exception as e:
            if type(e) in self.skip_exceptions:
                evaluations[program] = None
                return None
            raise
        return evaluations[program]

    def clear_cache(self) -> None:
        """Drop all memoised results (statistics are kept)."""
        self._cache = {}
        self._cons_cache = {}

    @property
    def cache_hit_rate(self) -> float:
        """
        Fraction of sub-program visits in ``eval`` that hit the cache.

        Returns 0.0 before any sub-program has been visited, instead of
        dividing by zero.
        """
        if self._total_requests == 0:
            return 0.0
        return self._cache_hits / self._total_requests
class DSLEvaluatorWithConstant(Evaluator):
    """
    DSL evaluator that supports the constant primitives ``cste_in``/``cste_out``.

    Those two primitives are bound per call to ``constant_in``/``constant_out``;
    every other primitive takes its value from ``semantics``. A ``Function``
    node applies its function's value to its arguments one at a time (curried
    application).

    With ``use_cache`` enabled, sub-program values are memoised in one of two
    tables:
    - ``_cache``, keyed by input, for programs containing no constant primitive;
    - ``_cons_cache``, keyed by (input, constant_in, constant_out), for programs
      that contain one.

    Sub-programs for which ``is_invariant(constant_types)`` is true are also
    memoised across all calls in ``_invariant_cache``, regardless of
    ``use_cache``.

    Attributes:
        semantics: primitive name -> semantic value.
        constant_types: primitive types treated as constants.
        use_cache: whether per-input results are kept between calls.
        skip_exceptions: exception classes (exact-type match via ``type(e)``)
            that make evaluation return None instead of propagating.
    """

    def __init__(
        self,
        semantics: Dict[str, Any],
        constant_types: Set[PrimitiveType],
        use_cache: bool = True,
    ) -> None:
        super().__init__()
        self.semantics = semantics
        self.constant_types = constant_types
        self.use_cache = use_cache
        # tuplified input -> {sub-program: value}, for constant-free programs
        self._cache: Dict[Any, Dict[Program, Any]] = {}
        # tuplified (input + constants) -> {sub-program: value}
        self._cons_cache: Dict[Any, Dict[Program, Any]] = {}
        # sub-program -> value, for sub-programs reported invariant
        self._invariant_cache: Dict[Program, Any] = {}
        # Holds exception *classes* (compared against type(e)), not instances.
        self.skip_exceptions: Set[Type[Exception]] = set()
        # Statistics
        self._total_requests = 0
        self._cache_hits = 0

    def _uses_constants(self, program: Program) -> bool:
        """Return True if some primitive of ``program`` is a constant."""
        # A primitive counts as a constant if its type is in constant_types or
        # if it is literally cste_in/cste_out. The name check keeps programs
        # reading the constants out of the input-only cache even if their type
        # was left out of constant_types.
        return any(
            isinstance(sub_prog, Primitive)
            and (
                sub_prog.type in self.constant_types
                or sub_prog.primitive in ("cste_in", "cste_out")
            )
            for sub_prog in program.depth_first_iter()
        )

    def _evaluations_for(
        self, program: Program, input: List, constant_in: str, constant_out: str
    ) -> Dict[Program, Any]:
        """Return the memo table for this call (a fresh dict when caching is off)."""
        if not self.use_cache:
            return {}
        if self._uses_constants(program):
            # list-unpacking works for lists and tuples alike (tuples have no .copy()).
            key = __tuplify__([*input, constant_in, constant_out])
            return self._cons_cache.setdefault(key, {})
        # setdefault stores the table so later calls on the same input reuse it.
        return self._cache.setdefault(__tuplify__(list(input)), {})

    def eval_with_constant(
        self, program: Program, input: List, constant_in: str, constant_out: str
    ) -> Any:
        """
        Evaluate ``program`` on ``input`` with the constants bound.

        Args:
            program: the program to run.
            input: the argument values; a ``Variable`` node reads
                ``input[variable]``.
            constant_in: value given to the ``cste_in`` primitive.
            constant_out: value given to the ``cste_out`` primitive.

        Returns:
            The program's value, or None if evaluation raised an exception
            whose exact type is in ``skip_exceptions``. In that case None is
            also stored in the memo table used for this call.

        Raises:
            Any other exception raised during evaluation, re-raised unchanged.
        """
        evaluations = self._evaluations_for(
            program, input, constant_in, constant_out
        )
        if program in evaluations:
            return evaluations[program]
        try:
            # NOTE(review): relies on depth_first_iter yielding every child
            # before its parent, since a Function node reads its children's
            # values from `evaluations`. Confirm in synth.syntax.program.
            for sub_prog in program.depth_first_iter():
                self._total_requests += 1
                if sub_prog in evaluations:
                    self._cache_hits += 1
                    continue
                invariant = sub_prog.is_invariant(self.constant_types)
                if invariant and sub_prog in self._invariant_cache:
                    self._cache_hits += 1
                    evaluations[sub_prog] = self._invariant_cache[sub_prog]
                    continue
                if isinstance(sub_prog, Primitive):
                    if sub_prog.primitive == "cste_in":
                        evaluations[sub_prog] = constant_in
                    elif sub_prog.primitive == "cste_out":
                        evaluations[sub_prog] = constant_out
                    else:
                        evaluations[sub_prog] = self.semantics[sub_prog.primitive]
                elif isinstance(sub_prog, Variable):
                    evaluations[sub_prog] = input[sub_prog.variable]
                elif isinstance(sub_prog, Function):
                    fun = evaluations[sub_prog.function]
                    for arg in sub_prog.arguments:
                        fun = fun(evaluations[arg])
                    evaluations[sub_prog] = fun
                if invariant:
                    # Stored only after a successful evaluation. A placeholder
                    # written beforehand would survive an exception and later be
                    # served as a bogus cached None.
                    self._invariant_cache[sub_prog] = evaluations[sub_prog]
        except Exception as e:
            if type(e) in self.skip_exceptions:
                evaluations[program] = None
                return None
            raise
        return evaluations[program]

    def eval(self, program: Program, input: List) -> Any:
        """
        Evaluate ``program`` on ``input``, extracting the constants from it.

        When ``input`` has at least 3 elements, ``input[0]`` and ``input[1]``
        are the values of ``cste_in`` and ``cste_out``, and the rest are the
        program's arguments. Otherwise the whole of ``input`` is the argument
        list and both constants are bound to the empty string.
        """
        # NOTE(review): the length-3 threshold decides whether constants are
        # present; confirm every caller builds inputs following this layout.
        if len(input) >= 3:
            return self.eval_with_constant(program, input[2:], input[0], input[1])
        return self.eval_with_constant(program, input, "", "")

    def clear_cache(self) -> None:
        """Drop all memoised results, including invariant ones (statistics are kept)."""
        self._cache = {}
        self._cons_cache = {}
        self._invariant_cache = {}

    @property
    def cache_hit_rate(self) -> float:
        """
        Fraction of sub-program visits that hit a cache (per-input or invariant).

        Returns 0.0 before any sub-program has been visited, instead of
        dividing by zero.
        """
        if self._total_requests == 0:
            return 0.0
        return self._cache_hits / self._total_requests