Skip to content

Commit

Permalink
Add support for comparing the equality of trees (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
renatahodovan authored May 24, 2024
1 parent 5ee7407 commit 2db57db
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion grammarinator/runtime/rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2023 Renata Hodovan, Akos Kiss.
# Copyright (c) 2017-2024 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
Expand Down Expand Up @@ -202,6 +202,18 @@ def remove(self):
self.parent.children.remove(self)
self.parent = None

def equals(self, other):
"""
Compare two nodes (potentially including any children) for equality.
The comparison is not implemented within ``__eq__`` to ensure that
nodes can be added to collections based on identity.
:param Rule other: The node to compare the current node to.
:return: Whether the two nodes are equal.
:rtype: bool
"""
return self.__class__ == other.__class__ and self.name == other.name

def _dbg_(self):
"""
Called by :meth:`__format__` to compute the "debug" string
Expand Down Expand Up @@ -299,6 +311,9 @@ def __iadd__(self, item):
self.add_child(item)
return self

def equals(self, other):
return super().equals(other) and len(self.children) == len(other.children) and all(child.equals(other.children[i]) for i, child in enumerate(self.children))

def __str__(self):
return ''.join(str(child) for child in self.children)

Expand Down Expand Up @@ -367,6 +382,9 @@ def __init__(self, *, name=None, src=None, size=None):
self.src = src or ''
self.size = size or (RuleSize(depth=1, tokens=1) if src else RuleSize(depth=0, tokens=0))

def equals(self, other):
return super().equals(other) and self.src == other.src

def __str__(self):
return self.src

Expand Down Expand Up @@ -399,6 +417,9 @@ def __init__(self, *, idx, start, stop, children=None):
self.start = start
self.stop = stop

def equals(self, other):
return super().equals(other) and self.idx == other.idx and self.start == other.start and self.stop == other.stop

def __repr__(self):
parts = [
f'idx={self.idx}',
Expand Down Expand Up @@ -444,6 +465,9 @@ def __init__(self, *, alt_idx, idx, children=None):
self.alt_idx = alt_idx
self.idx = idx

def equals(self, other):
return super().equals(other) and self.alt_idx == other.alt_idx and self.idx == other.idx

def __repr__(self):
parts = [
f'alt_idx={self.alt_idx}',
Expand Down

0 comments on commit 2db57db

Please sign in to comment.