Improved typing around LALR and ParserState #1343

Merged · 4 commits · Oct 2, 2023
Changes from all commits
20 changes: 12 additions & 8 deletions lark/common.py
@@ -1,11 +1,12 @@
from copy import deepcopy
import sys
from types import ModuleType
from typing import Callable, Collection, Dict, Optional, TYPE_CHECKING
from typing import Callable, Collection, Dict, Optional, TYPE_CHECKING, List

if TYPE_CHECKING:
from .lark import PostLex
from .lexer import Lexer
from .grammar import Rule
from typing import Union, Type
if sys.version_info >= (3, 8):
from typing import Literal
@@ -23,7 +24,8 @@

_ParserArgType: 'TypeAlias' = 'Literal["earley", "lalr", "cyk", "auto"]'
_LexerArgType: 'TypeAlias' = 'Union[Literal["auto", "basic", "contextual", "dynamic", "dynamic_complete"], Type[Lexer]]'
_Callback = Callable[[Token], Token]
_LexerCallback = Callable[[Token], Token]
ParserCallbacks = Dict[str, Callable]

class LexerConf(Serialize):
__serialize_fields__ = 'terminals', 'ignore', 'g_regex_flags', 'use_bytes', 'lexer_type'
@@ -33,15 +35,15 @@ class LexerConf(Serialize):
re_module: ModuleType
ignore: Collection[str]
postlex: 'Optional[PostLex]'
callbacks: Dict[str, _Callback]
callbacks: Dict[str, _LexerCallback]
g_regex_flags: int
skip_validation: bool
use_bytes: bool
lexer_type: Optional[_LexerArgType]
strict: bool

def __init__(self, terminals: Collection[TerminalDef], re_module: ModuleType, ignore: Collection[str]=(), postlex: 'Optional[PostLex]'=None,
callbacks: Optional[Dict[str, _Callback]]=None, g_regex_flags: int=0, skip_validation: bool=False, use_bytes: bool=False, strict: bool=False):
callbacks: Optional[Dict[str, _LexerCallback]]=None, g_regex_flags: int=0, skip_validation: bool=False, use_bytes: bool=False, strict: bool=False):
self.terminals = terminals
self.terminals_by_name = {t.name: t for t in self.terminals}
assert len(self.terminals) == len(self.terminals_by_name)
@@ -70,16 +72,18 @@ def __deepcopy__(self, memo=None):
deepcopy(self.use_bytes, memo),
)


class ParserConf(Serialize):
__serialize_fields__ = 'rules', 'start', 'parser_type'

def __init__(self, rules, callbacks, start):
rules: List['Rule']
callbacks: ParserCallbacks
start: List[str]
parser_type: _ParserArgType

def __init__(self, rules: List['Rule'], callbacks: ParserCallbacks, start: List[str]):
assert isinstance(start, list)
self.rules = rules
self.callbacks = callbacks
self.start = start

self.parser_type = None

###}
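
Not part of the diff: a minimal sketch of how the renamed _LexerCallback alias (Callable[[Token], Token]) is used in practice. The grammar string and the upper_name callback are illustrative assumptions, not code from this PR:

from lark import Lark, Token

def upper_name(tok: Token) -> Token:
    # A lexer callback matching _LexerCallback: takes a Token, returns a Token
    return tok.update(value=tok.value.upper())

parser = Lark('start: NAME\n%import common.CNAME -> NAME\n%ignore " "',
              parser='lalr',
              lexer_callbacks={'NAME': upper_name})
print(parser.parse('hello'))  # tree containing Token('NAME', 'HELLO')
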
2 changes: 1 addition & 1 deletion lark/exceptions.py
@@ -50,6 +50,7 @@ class UnexpectedInput(LarkError):
pos_in_stream = None
state: Any
_terminals_by_name = None
interactive_parser: 'InteractiveParser'

def get_context(self, text: str, span: int=40) -> str:
"""Returns a pretty string pinpointing the error in the text,
@@ -225,7 +226,6 @@ class UnexpectedToken(ParseError, UnexpectedInput):

expected: Set[str]
considered_rules: Set[str]
interactive_parser: 'InteractiveParser'

def __init__(self, token, expected, considered_rules=None, state=None, interactive_parser=None, terminals_by_name=None, token_history=None):
super(UnexpectedToken, self).__init__()
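
Not part of the diff: a hedged sketch of why the interactive_parser annotation now lives on UnexpectedInput. Error handlers passed to Lark.parse(on_error=...) receive any UnexpectedInput, so the attribute can be read without first narrowing to UnexpectedToken. The grammar and input below are placeholders:

from lark import Lark, UnexpectedInput

parser = Lark('start: "a" "b"', parser='lalr')

def handle(e: UnexpectedInput) -> bool:
    # Annotated on the base class, so this access type-checks for any
    # UnexpectedInput; at runtime the LALR parser attaches it on failure.
    print('acceptable next tokens:', e.interactive_parser.accepts())
    return False  # do not recover; let the error propagate

try:
    parser.parse('aa', on_error=handle)
except UnexpectedInput:
    pass
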
12 changes: 10 additions & 2 deletions lark/grammar.py
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, ClassVar
from typing import Optional, Tuple, ClassVar, Sequence

from .utils import Serialize

@@ -93,7 +93,15 @@ class Rule(Serialize):
__serialize_fields__ = 'origin', 'expansion', 'order', 'alias', 'options'
__serialize_namespace__ = Terminal, NonTerminal, RuleOptions

def __init__(self, origin, expansion, order=0, alias=None, options=None):
origin: NonTerminal
expansion: Sequence[Symbol]
order: int
alias: Optional[str]
options: RuleOptions
_hash: int

def __init__(self, origin: NonTerminal, expansion: Sequence[Symbol],
order: int=0, alias: Optional[str]=None, options: Optional[RuleOptions]=None):
self.origin = origin
self.expansion = expansion
self.alias = alias
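
Not part of the diff: with the new annotations, constructing a Rule directly reads like this sketch; the symbols are made-up examples:

from lark.grammar import Rule, NonTerminal, Terminal

# origin: NonTerminal, expansion: Sequence[Symbol]; order, alias and options are optional
rule = Rule(NonTerminal('value'), [Terminal('NAME'), NonTerminal('suffix')], order=1)
print(rule.origin, rule.expansion, rule.order)
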
10 changes: 5 additions & 5 deletions lark/lexer.py
@@ -15,6 +15,7 @@
pass
if TYPE_CHECKING:
from .common import LexerConf
from .parsers.lalr_parser_state import ParserState

from .utils import classify, get_regexp_width, Serialize, logger
from .exceptions import UnexpectedCharacters, LexError, UnexpectedToken
Expand Down Expand Up @@ -436,7 +437,7 @@ def __init__(self, lexer: 'Lexer', lexer_state: LexerState):
self.state = lexer_state

@classmethod
def from_text(cls, lexer: 'Lexer', text: str):
def from_text(cls, lexer: 'Lexer', text: str) -> 'LexerThread':
return cls(lexer, LexerState(text))

def lex(self, parser_state):
Expand Down Expand Up @@ -622,13 +623,12 @@ def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:


class ContextualLexer(Lexer):

lexers: Dict[str, AbstractBasicLexer]
lexers: Dict[int, AbstractBasicLexer]
root_lexer: AbstractBasicLexer

BasicLexer: Type[AbstractBasicLexer] = BasicLexer

def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always_accept: Collection[str]=()) -> None:
def __init__(self, conf: 'LexerConf', states: Dict[int, Collection[str]], always_accept: Collection[str]=()) -> None:
terminals = list(conf.terminals)
terminals_by_name = conf.terminals_by_name

@@ -658,7 +658,7 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always
trad_conf.skip_validation = True # We don't need to verify all terminals again
self.root_lexer = self.BasicLexer(trad_conf, comparator)

def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
def lex(self, lexer_state: LexerState, parser_state: 'ParserState') -> Iterator[Token]:
try:
while True:
lexer = self.lexers[parser_state.position]
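
Not part of the diff: the states mapping handed to ContextualLexer is now keyed by integer LALR state ids rather than strings. A hedged illustration of its shape, with placeholder terminal names:

from typing import Collection, Dict

states: Dict[int, Collection[str]] = {
    0: ['NAME', 'LPAR'],   # hypothetical state 0 and the terminals valid there
    1: ['RPAR', 'COMMA'],  # hypothetical state 1
}
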
9 changes: 5 additions & 4 deletions lark/load_grammar.py
@@ -684,7 +684,7 @@ def __init__(self, rule_defs: List[Tuple[str, Tuple[str, ...], Tree, RuleOptions
self.rule_defs = rule_defs
self.ignore = ignore

def compile(self, start, terminals_to_keep):
def compile(self, start, terminals_to_keep) -> Tuple[List[TerminalDef], List[Rule], List[str]]:
# We change the trees in-place (to support huge grammars)
# So deepcopy allows calling compile more than once.
term_defs = [(n, (nr_deepcopy_tree(t), p)) for n, (t, p) in self.term_defs]
@@ -733,7 +733,7 @@ def compile(self, start, terminals_to_keep):
ebnf_to_bnf.prefix = name
anon_tokens_transf.rule_options = rule_options
tree = transformer.transform(rule_tree)
res = ebnf_to_bnf.transform(tree)
res: Tree = ebnf_to_bnf.transform(tree)
rules.append((name, res, options))
rules += ebnf_to_bnf.new_rules

@@ -743,7 +743,7 @@ def compile(self, start, terminals_to_keep):
rule_tree_to_text = RuleTreeToText()

simplify_rule = SimplifyRule_Visitor()
compiled_rules = []
compiled_rules: List[Rule] = []
for rule_content in rules:
name, tree, options = rule_content
simplify_rule.visit(tree)
@@ -753,7 +753,7 @@ def compile(self, start, terminals_to_keep):
if alias and name.startswith('_'):
raise GrammarError("Rule %s is marked for expansion (it starts with an underscore) and isn't allowed to have aliases (alias=%s)"% (name, alias))

empty_indices = [x==_EMPTY for x in expansion]
empty_indices = tuple(x==_EMPTY for x in expansion)
if any(empty_indices):
exp_options = copy(options) or RuleOptions()
exp_options.empty_indices = empty_indices
@@ -764,6 +764,7 @@
for sym in expansion:
assert isinstance(sym, Symbol)
if sym.is_term and exp_options and exp_options.keep_all_tokens:
assert isinstance(sym, Terminal)
sym.filter_out = False
rule = Rule(NonTerminal(name), expansion, i, alias, exp_options)
compiled_rules.append(rule)
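
Not part of the diff: Grammar.compile is now annotated to return the terminal definitions, the compiled rules and the ignored terminal names. A hedged sketch of calling it directly, assuming load_grammar keeps its (grammar, source, import_paths, global_keep_all_tokens) signature:

from lark.load_grammar import load_grammar

grammar, used_files = load_grammar('start: "a"', '<example>', [], False)
terminals, rules, ignore_names = grammar.compile(['start'], set())
print([t.name for t in terminals], rules, ignore_names)
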
13 changes: 8 additions & 5 deletions lark/parser_frontends.py
@@ -1,14 +1,17 @@
from typing import Any, Callable, Dict, Optional, Collection
from typing import Any, Callable, Dict, Optional, Collection, Union, TYPE_CHECKING

from .exceptions import ConfigurationError, GrammarError, assert_config
from .utils import get_regexp_width, Serialize
from .parsers.grammar_analysis import GrammarAnalyzer
from .lexer import LexerThread, BasicLexer, ContextualLexer, Lexer
from .parsers import earley, xearley, cyk
from .parsers.lalr_parser import LALR_Parser
from .tree import Tree
from .common import LexerConf, ParserConf, _ParserArgType, _LexerArgType

if TYPE_CHECKING:
from .parsers.lalr_analysis import ParseTableBase


###{standalone

def _wrap_lexer(lexer_class):
@@ -90,7 +93,7 @@ def _verify_start(self, start=None):
raise ConfigurationError("Unknown start rule %s. Must be one of %r" % (start, self.parser_conf.start))
return start

def _make_lexer_thread(self, text: str):
def _make_lexer_thread(self, text: str) -> Union[str, LexerThread]:
cls = (self.options and self.options._plugins.get('LexerThread')) or LexerThread
return text if self.skip_lexer else cls.from_text(self.lexer, text)

@@ -146,7 +149,8 @@ def create_basic_lexer(lexer_conf, parser, postlex, options) -> BasicLexer:

def create_contextual_lexer(lexer_conf: LexerConf, parser, postlex, options) -> ContextualLexer:
cls = (options and options._plugins.get('ContextualLexer')) or ContextualLexer
states: Dict[str, Collection[str]] = {idx:list(t.keys()) for idx, t in parser._parse_table.states.items()}
parse_table: ParseTableBase[int] = parser._parse_table
states: Dict[int, Collection[str]] = {idx:list(t.keys()) for idx, t in parse_table.states.items()}
always_accept: Collection[str] = postlex.always_accept if postlex else ()
return cls(lexer_conf, states, always_accept=always_accept)

@@ -215,7 +219,6 @@ def create_earley_parser(lexer_conf: LexerConf, parser_conf: ParserConf, options

class CYK_FrontEnd:
def __init__(self, lexer_conf, parser_conf, options=None):
# self._analysis = GrammarAnalyzer(parser_conf)
self.parser = cyk.Parser(parser_conf.rules)

self.callbacks = parser_conf.callbacks
37 changes: 26 additions & 11 deletions lark/parsers/grammar_analysis.py
@@ -1,16 +1,20 @@
"Provides for superficial grammar analysis."

from collections import Counter, defaultdict
from typing import List, Dict, Iterator, FrozenSet, Set

from ..utils import bfs, fzset, classify
from ..exceptions import GrammarError
from ..grammar import Rule, Terminal, NonTerminal
from ..grammar import Rule, Terminal, NonTerminal, Symbol
from ..common import ParserConf


class RulePtr:
__slots__ = ('rule', 'index')
rule: Rule
index: int

def __init__(self, rule, index):
def __init__(self, rule: Rule, index: int):
assert isinstance(rule, Rule)
assert index <= len(rule.expansion)
self.rule = rule
@@ -22,27 +26,37 @@ def __repr__(self):
return '<%s : %s * %s>' % (self.rule.origin.name, ' '.join(before), ' '.join(after))

@property
def next(self):
def next(self) -> Symbol:
return self.rule.expansion[self.index]

def advance(self, sym):
def advance(self, sym: Symbol) -> 'RulePtr':
assert self.next == sym
return RulePtr(self.rule, self.index+1)

@property
def is_satisfied(self):
def is_satisfied(self) -> bool:
return self.index == len(self.rule.expansion)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if not isinstance(other, RulePtr):
return NotImplemented
return self.rule == other.rule and self.index == other.index
def __hash__(self):

def __hash__(self) -> int:
return hash((self.rule, self.index))


State = FrozenSet[RulePtr]

# state generation ensures no duplicate LR0ItemSets
class LR0ItemSet:
__slots__ = ('kernel', 'closure', 'transitions', 'lookaheads')

kernel: State
closure: State
transitions: Dict[Symbol, 'LR0ItemSet']
lookaheads: Dict[Symbol, Set[Rule]]

def __init__(self, kernel, closure):
self.kernel = fzset(kernel)
self.closure = fzset(closure)
@@ -124,15 +138,15 @@ def calculate_sets(rules):


class GrammarAnalyzer:
def __init__(self, parser_conf, debug=False, strict=False):
def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False):
self.debug = debug
self.strict = strict

root_rules = {start: Rule(NonTerminal('$root_' + start), [NonTerminal(start), Terminal('$END')])
for start in parser_conf.start}

rules = parser_conf.rules + list(root_rules.values())
self.rules_by_origin = classify(rules, lambda r: r.origin)
self.rules_by_origin: Dict[NonTerminal, List[Rule]] = classify(rules, lambda r: r.origin)

if len(rules) != len(set(rules)):
duplicates = [item for item, count in Counter(rules).items() if count > 1]
@@ -163,14 +177,14 @@ def __init__(self, parser_conf, debug=False, strict=False):

self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules)

def expand_rule(self, source_rule, rules_by_origin=None):
def expand_rule(self, source_rule: NonTerminal, rules_by_origin=None) -> State:
"Returns all init_ptrs accessible by rule (recursive)"

if rules_by_origin is None:
rules_by_origin = self.rules_by_origin

init_ptrs = set()
def _expand_rule(rule):
def _expand_rule(rule: NonTerminal) -> Iterator[NonTerminal]:
assert not rule.is_term, rule

for r in rules_by_origin[rule]:
@@ -180,6 +194,7 @@ def _expand_rule(rule):
if r.expansion: # if not empty rule
new_r = init_ptr.next
if not new_r.is_term:
assert isinstance(new_r, NonTerminal)
yield new_r

for _ in bfs([source_rule], _expand_rule):
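
Not part of the diff: RulePtr is a dotted LR item, a rule plus a cursor index into its expansion, and State is a frozen set of such items. A small hedged sketch of the now-typed API; the rule is a placeholder:

from lark.grammar import Rule, NonTerminal, Terminal
from lark.parsers.grammar_analysis import RulePtr

rule = Rule(NonTerminal('start'), [Terminal('A'), Terminal('B')])
ptr = RulePtr(rule, 0)
assert ptr.next == Terminal('A')      # the Symbol to the right of the dot
ptr2 = ptr.advance(Terminal('A'))     # returns a new RulePtr with index + 1
assert not ptr2.is_satisfied          # the dot is not yet at the end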