Skip to content

Commit

Permalink
perf(ir): don't recreate nodes in replace if their children haven't…
Browse files Browse the repository at this point in the history
… changed
  • Loading branch information
jcrist committed Sep 4, 2024
1 parent e6f66c6 commit 0c7e727
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 44 deletions.
23 changes: 11 additions & 12 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def sqlize(

# lower the expression graph to a SQL-like relational algebra
context = {"params": params}
sqlized = node.replace(
result = node.replace(
replace_parameter
| project_to_select
| filter_to_select
Expand All @@ -385,24 +385,23 @@ def sqlize(

# squash subsequent Select nodes into one
if fuse_selects:
simplified = sqlized.replace(merge_select_select)
else:
simplified = sqlized
result = result.replace(merge_select_select)

if post_rewrites:
simplified = simplified.replace(reduce(operator.or_, post_rewrites))
result = result.replace(reduce(operator.or_, post_rewrites))

# extract common table expressions while wrapping them in a CTE node
ctes = extract_ctes(simplified)
ctes = extract_ctes(result)

def wrap(node, _, **kwargs):
new = node.__recreate__(kwargs)
return CTE(new) if node in ctes else new
if ctes:

result = simplified.replace(wrap)
ctes = [cte.parent for cte in result.find(CTE, ordered=True)]
def apply_ctes(node, kwargs):
new = node.__recreate__(kwargs) if kwargs else node
return CTE(new) if node in ctes else new

return result, ctes
result = result.replace(apply_ctes)
return result, [cte.parent for cte in result.find(CTE, ordered=True)]
return result, []


# supplemental rewrites selectively used on a per-backend basis
Expand Down
16 changes: 7 additions & 9 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,16 +1382,14 @@ def test_histogram(con, alltypes):
hist = con.execute(alltypes.int_col.histogram(n).name("hist"))
vc = hist.value_counts().sort_index()
vc_np, _bin_edges = np.histogram(alltypes.int_col.execute(), bins=n)
assert vc.tolist() == vc_np.tolist()
assert (
con.execute(
ibis.memtable({"value": range(100)})
.select(bin=_.value.histogram(10))
.value_counts()
.bin_count.nunique()
)
== 1
expr = (
ibis.memtable({"value": range(100)})
.select(bin=_.value.histogram(10))
.value_counts()
.bin_count.nunique()
)
assert vc.tolist() == vc_np.tolist()
assert con.execute(expr) == 1


@pytest.mark.parametrize("const", ["pi", "e"])
Expand Down
94 changes: 77 additions & 17 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Finder = Callable[["Node"], bool]
FinderLike = Union[Finder, Pattern, _ClassInfo]

Replacer = Callable[["Node", dict["Node", Any]], "Node"]
Replacer = Callable[["Node", dict["Node", Any] | None], "Node"]
ReplacerLike = Union[Replacer, Pattern, Mapping]


Expand Down Expand Up @@ -127,6 +127,47 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any:
return obj


def _apply_replacements(obj: Any, replacements: dict) -> tuple[Any, bool]:
"""Replace nodes in a possibly nested object.
Parameters
----------
obj
The object to traverse.
replacements
A mapping of replacement values.
Returns
-------
tuple[Any, bool]
A tuple of the replaced object and whether any replacements were made.
"""
if isinstance(obj, Node):
val = replacements.get(obj)
return (obj, False) if val is None else (val, True)
typ = type(obj)
if typ in (tuple, frozenset, list):
changed = False
items = []
for i in obj:
i, ichanged = _apply_replacements(i, replacements)
changed |= ichanged
items.append(i)
return typ(items), changed
elif isinstance(obj, dict):
changed = False
items = {}
for k, v in obj.items():
k, kchanged = _apply_replacements(k, replacements)
v, vchanged = _apply_replacements(v, replacements)
changed |= kchanged
changed |= vchanged
items[k] = v
return items, changed
else:
return obj, False


def _coerce_finder(obj: FinderLike, context: Optional[dict] = None) -> Finder:
"""Coerce an object into a callable finder function.
Expand Down Expand Up @@ -165,8 +206,7 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
Parameters
----------
obj
A Pattern, a Mapping or a callable which can be fed to `node.map()`
to replace nodes.
A Pattern, Mapping, or Callable.
context
Optional context to use if the replacer is a pattern.
Expand All @@ -177,26 +217,26 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla
"""
if isinstance(obj, Pattern):

def fn(node, _, **kwargs):
def fn(node, kwargs):
ctx = context or {}
# need to first reconstruct the node from the possible rewritten
# children, so we can match on the new node containing the rewritten
# child arguments, this way we can propagate the rewritten nodes
# upward in the hierarchy, using a specialized __recreate__ method
# improves the performance by 17% compared node.__class__(**kwargs)
recreated = node.__recreate__(kwargs)
# upward in the hierarchy
recreated = node.__recreate__(kwargs) if kwargs else node
if (result := obj.match(recreated, ctx)) is NoMatch:
return recreated
else:
return result
return result

elif isinstance(obj, Mapping):

def fn(node, _, **kwargs):
def fn(node, kwargs):
# For a mapping we want to lookup the original node first, and
# return a recreated one from the children if it's not present
try:
return obj[node]
except KeyError:
return node.__recreate__(kwargs)
return node.__recreate__(kwargs) if kwargs else node
elif callable(obj):
fn = obj
else:
Expand Down Expand Up @@ -313,7 +353,7 @@ def map_clear(self, fn: Callable, filter: Optional[Finder] = None) -> Any:
if not dependents[dependency]:
del results[dependency]

return results[self]
return results.get(self, self)

@experimental
def map_nodes(self, fn: Callable, filter: Optional[Finder] = None) -> Any:
Expand Down Expand Up @@ -451,8 +491,9 @@ def replace(
Parameters
----------
replacer
A `Pattern`, a `Mapping` or a callable which can be fed to
`node.map()` directly to replace nodes.
A `Pattern`, `Mapping` or Callable taking the original unrewritten
node, and a mapping of attribute name to value of its rewritten
children (or None if no children were rewritten).
filter
A type, tuple of types, a pattern or a callable to filter out nodes
from the traversal. The traversal will only visit nodes that match
Expand All @@ -465,9 +506,28 @@ def replace(
The root node of the graph with the replaced nodes.
"""
replacer = _coerce_replacer(replacer, context)
results = self.map(replacer, filter=filter)
return results.get(self, self)
replacements: dict[Node, Any] = {}

fn = _coerce_replacer(replacer, context)

graph, _ = Graph.from_bfs(self, filter=filter).toposort()
for node in graph:
kwargs = {}
# Apply already rewritten nodes to the children of the node
changed = False
for k, v in zip(node.__argnames__, node.__args__):
v, vchanged = _apply_replacements(v, replacements)
changed |= vchanged
kwargs[k] = v

# Call the replacer on the node with any rewritten nodes (or None
# if unchanged).
result = fn(node, kwargs if changed else None)
if result is not node:
# The node is changed, store it in the mapping of replacements
replacements[node] = result

return replacements.get(self, self)


class Graph(dict[Node, Sequence[Node]]):
Expand Down
43 changes: 37 additions & 6 deletions ibis/common/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
traverse,
)
from ibis.common.grounds import Annotable, Concrete
from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _
from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _, pattern


class MyNode(Node):
Expand Down Expand Up @@ -170,6 +170,36 @@ def test_replace_with_mapping():
assert result == new_A


@pytest.mark.parametrize("kind", ["pattern", "mapping", "function"])
def test_replace_doesnt_recreate_unchanged_nodes(kind):
A1 = MyNode(name="A1", children=[])
A2 = MyNode(name="A2", children=[A1])
B1 = MyNode(name="B1", children=[])
B2 = MyNode(name="B2", children=[B1])
C = MyNode(name="C", children=[A2, B2])

B3 = MyNode(name="B3", children=[])

if kind == "pattern":
replacer = pattern(MyNode)(name="B2") >> B3
elif kind == "mapping":
replacer = {B2: B3}
else:

def replacer(node, children):
if node is B2:
return B3
return node.__recreate__(children) if children else node

res = C.replace(replacer)

assert res is not C
assert res.name == "C"
assert len(res.children) == 2
assert res.children[0] is A2
assert res.children[1] is B3


def test_example():
class Example(Annotable, Node):
def __hash__(self):
Expand Down Expand Up @@ -343,17 +373,18 @@ def test_coerce_finder():


def test_coerce_replacer():
r = _coerce_replacer(lambda x, _, **kwargs: D)
assert r(C, {}) == D
r = _coerce_replacer(lambda x, children: D if children else C)
assert r(C, {"children": []}) is D
assert r(C, None) is C

r = _coerce_replacer({C: D, D: E})
assert r(C, {}) == D
assert r(D, {}) == E
assert r(A, {}, name="A", children=[B, C]) == A
assert r(A, {"name": "A", "children": [B, C]}) == A

r = _coerce_replacer(InstanceOf(MyNode) >> _.copy(name=_.name.lower()))
assert r(C, {}, name="C", children=[]) == MyNode(name="c", children=[])
assert r(D, {}, name="D", children=[]) == MyNode(name="d", children=[])
assert r(C, {"name": "C", "children": []}) == MyNode(name="c", children=[])
assert r(D, {"name": "D", "children": []}) == MyNode(name="d", children=[])


def test_node_find_using_type():
Expand Down

0 comments on commit 0c7e727

Please sign in to comment.