diff --git a/lark/common.py b/lark/common.py index 870b51547..d6be890a0 100644 --- a/lark/common.py +++ b/lark/common.py @@ -1,11 +1,12 @@ from copy import deepcopy import sys from types import ModuleType -from typing import Callable, Collection, Dict, Optional, TYPE_CHECKING +from typing import Callable, Collection, Dict, Optional, TYPE_CHECKING, List if TYPE_CHECKING: from .lark import PostLex from .lexer import Lexer + from .grammar import Rule from typing import Union, Type if sys.version_info >= (3, 8): from typing import Literal @@ -23,7 +24,8 @@ _ParserArgType: 'TypeAlias' = 'Literal["earley", "lalr", "cyk", "auto"]' _LexerArgType: 'TypeAlias' = 'Union[Literal["auto", "basic", "contextual", "dynamic", "dynamic_complete"], Type[Lexer]]' -_Callback = Callable[[Token], Token] +_LexerCallback = Callable[[Token], Token] +ParserCallbacks = Dict[str, Callable] class LexerConf(Serialize): __serialize_fields__ = 'terminals', 'ignore', 'g_regex_flags', 'use_bytes', 'lexer_type' @@ -33,7 +35,7 @@ class LexerConf(Serialize): re_module: ModuleType ignore: Collection[str] postlex: 'Optional[PostLex]' - callbacks: Dict[str, _Callback] + callbacks: Dict[str, _LexerCallback] g_regex_flags: int skip_validation: bool use_bytes: bool @@ -41,7 +43,7 @@ class LexerConf(Serialize): strict: bool def __init__(self, terminals: Collection[TerminalDef], re_module: ModuleType, ignore: Collection[str]=(), postlex: 'Optional[PostLex]'=None, - callbacks: Optional[Dict[str, _Callback]]=None, g_regex_flags: int=0, skip_validation: bool=False, use_bytes: bool=False, strict: bool=False): + callbacks: Optional[Dict[str, _LexerCallback]]=None, g_regex_flags: int=0, skip_validation: bool=False, use_bytes: bool=False, strict: bool=False): self.terminals = terminals self.terminals_by_name = {t.name: t for t in self.terminals} assert len(self.terminals) == len(self.terminals_by_name) @@ -70,16 +72,18 @@ def __deepcopy__(self, memo=None): deepcopy(self.use_bytes, memo), ) - class ParserConf(Serialize): __serialize_fields__ = 'rules', 'start', 'parser_type' - def __init__(self, rules, callbacks, start): + rules: List['Rule'] + callbacks: ParserCallbacks + start: List[str] + parser_type: _ParserArgType + + def __init__(self, rules: List['Rule'], callbacks: ParserCallbacks, start: List[str]): assert isinstance(start, list) self.rules = rules self.callbacks = callbacks self.start = start - self.parser_type = None - ###} diff --git a/lark/exceptions.py b/lark/exceptions.py index 32f0930a7..e099d596c 100644 --- a/lark/exceptions.py +++ b/lark/exceptions.py @@ -50,6 +50,7 @@ class UnexpectedInput(LarkError): pos_in_stream = None state: Any _terminals_by_name = None + interactive_parser: 'InteractiveParser' def get_context(self, text: str, span: int=40) -> str: """Returns a pretty string pinpointing the error in the text, @@ -225,7 +226,6 @@ class UnexpectedToken(ParseError, UnexpectedInput): expected: Set[str] considered_rules: Set[str] - interactive_parser: 'InteractiveParser' def __init__(self, token, expected, considered_rules=None, state=None, interactive_parser=None, terminals_by_name=None, token_history=None): super(UnexpectedToken, self).__init__() diff --git a/lark/grammar.py b/lark/grammar.py index 4f4fa90b5..1d226d9e4 100644 --- a/lark/grammar.py +++ b/lark/grammar.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, ClassVar +from typing import Optional, Tuple, ClassVar, Sequence from .utils import Serialize @@ -93,7 +93,15 @@ class Rule(Serialize): __serialize_fields__ = 'origin', 'expansion', 'order', 'alias', 
'options' __serialize_namespace__ = Terminal, NonTerminal, RuleOptions - def __init__(self, origin, expansion, order=0, alias=None, options=None): + origin: NonTerminal + expansion: Sequence[Symbol] + order: int + alias: Optional[str] + options: RuleOptions + _hash: int + + def __init__(self, origin: NonTerminal, expansion: Sequence[Symbol], + order: int=0, alias: Optional[str]=None, options: Optional[RuleOptions]=None): self.origin = origin self.expansion = expansion self.alias = alias diff --git a/lark/lexer.py b/lark/lexer.py index 2fba894f0..9061d6001 100644 --- a/lark/lexer.py +++ b/lark/lexer.py @@ -15,6 +15,7 @@ pass if TYPE_CHECKING: from .common import LexerConf + from .parsers.lalr_parser_state import ParserState from .utils import classify, get_regexp_width, Serialize, logger from .exceptions import UnexpectedCharacters, LexError, UnexpectedToken @@ -436,7 +437,7 @@ def __init__(self, lexer: 'Lexer', lexer_state: LexerState): self.state = lexer_state @classmethod - def from_text(cls, lexer: 'Lexer', text: str): + def from_text(cls, lexer: 'Lexer', text: str) -> 'LexerThread': return cls(lexer, LexerState(text)) def lex(self, parser_state): @@ -622,13 +623,12 @@ def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token: class ContextualLexer(Lexer): - - lexers: Dict[str, AbstractBasicLexer] + lexers: Dict[int, AbstractBasicLexer] root_lexer: AbstractBasicLexer BasicLexer: Type[AbstractBasicLexer] = BasicLexer - def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always_accept: Collection[str]=()) -> None: + def __init__(self, conf: 'LexerConf', states: Dict[int, Collection[str]], always_accept: Collection[str]=()) -> None: terminals = list(conf.terminals) terminals_by_name = conf.terminals_by_name @@ -658,7 +658,7 @@ def __init__(self, conf: 'LexerConf', states: Dict[str, Collection[str]], always trad_conf.skip_validation = True # We don't need to verify all terminals again self.root_lexer = self.BasicLexer(trad_conf, comparator) - def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + def lex(self, lexer_state: LexerState, parser_state: 'ParserState') -> Iterator[Token]: try: while True: lexer = self.lexers[parser_state.position] diff --git a/lark/load_grammar.py b/lark/load_grammar.py index fcd25c316..8e41775f3 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -684,7 +684,7 @@ def __init__(self, rule_defs: List[Tuple[str, Tuple[str, ...], Tree, RuleOptions self.rule_defs = rule_defs self.ignore = ignore - def compile(self, start, terminals_to_keep): + def compile(self, start, terminals_to_keep) -> Tuple[List[TerminalDef], List[Rule], List[str]]: # We change the trees in-place (to support huge grammars) # So deepcopy allows calling compile more than once. 
term_defs = [(n, (nr_deepcopy_tree(t), p)) for n, (t, p) in self.term_defs] @@ -733,7 +733,7 @@ def compile(self, start, terminals_to_keep): ebnf_to_bnf.prefix = name anon_tokens_transf.rule_options = rule_options tree = transformer.transform(rule_tree) - res = ebnf_to_bnf.transform(tree) + res: Tree = ebnf_to_bnf.transform(tree) rules.append((name, res, options)) rules += ebnf_to_bnf.new_rules @@ -743,7 +743,7 @@ def compile(self, start, terminals_to_keep): rule_tree_to_text = RuleTreeToText() simplify_rule = SimplifyRule_Visitor() - compiled_rules = [] + compiled_rules: List[Rule] = [] for rule_content in rules: name, tree, options = rule_content simplify_rule.visit(tree) @@ -753,7 +753,7 @@ def compile(self, start, terminals_to_keep): if alias and name.startswith('_'): raise GrammarError("Rule %s is marked for expansion (it starts with an underscore) and isn't allowed to have aliases (alias=%s)"% (name, alias)) - empty_indices = [x==_EMPTY for x in expansion] + empty_indices = tuple(x==_EMPTY for x in expansion) if any(empty_indices): exp_options = copy(options) or RuleOptions() exp_options.empty_indices = empty_indices @@ -764,6 +764,7 @@ def compile(self, start, terminals_to_keep): for sym in expansion: assert isinstance(sym, Symbol) if sym.is_term and exp_options and exp_options.keep_all_tokens: + assert isinstance(sym, Terminal) sym.filter_out = False rule = Rule(NonTerminal(name), expansion, i, alias, exp_options) compiled_rules.append(rule) diff --git a/lark/parser_frontends.py b/lark/parser_frontends.py index e7b19ff1d..186058a6b 100644 --- a/lark/parser_frontends.py +++ b/lark/parser_frontends.py @@ -1,14 +1,17 @@ -from typing import Any, Callable, Dict, Optional, Collection +from typing import Any, Callable, Dict, Optional, Collection, Union, TYPE_CHECKING from .exceptions import ConfigurationError, GrammarError, assert_config from .utils import get_regexp_width, Serialize -from .parsers.grammar_analysis import GrammarAnalyzer from .lexer import LexerThread, BasicLexer, ContextualLexer, Lexer from .parsers import earley, xearley, cyk from .parsers.lalr_parser import LALR_Parser from .tree import Tree from .common import LexerConf, ParserConf, _ParserArgType, _LexerArgType +if TYPE_CHECKING: + from .parsers.lalr_analysis import ParseTableBase + + ###{standalone def _wrap_lexer(lexer_class): @@ -90,7 +93,7 @@ def _verify_start(self, start=None): raise ConfigurationError("Unknown start rule %s. 
Must be one of %r" % (start, self.parser_conf.start)) return start - def _make_lexer_thread(self, text: str): + def _make_lexer_thread(self, text: str) -> Union[str, LexerThread]: cls = (self.options and self.options._plugins.get('LexerThread')) or LexerThread return text if self.skip_lexer else cls.from_text(self.lexer, text) @@ -146,7 +149,8 @@ def create_basic_lexer(lexer_conf, parser, postlex, options) -> BasicLexer: def create_contextual_lexer(lexer_conf: LexerConf, parser, postlex, options) -> ContextualLexer: cls = (options and options._plugins.get('ContextualLexer')) or ContextualLexer - states: Dict[str, Collection[str]] = {idx:list(t.keys()) for idx, t in parser._parse_table.states.items()} + parse_table: ParseTableBase[int] = parser._parse_table + states: Dict[int, Collection[str]] = {idx:list(t.keys()) for idx, t in parse_table.states.items()} always_accept: Collection[str] = postlex.always_accept if postlex else () return cls(lexer_conf, states, always_accept=always_accept) @@ -215,7 +219,6 @@ def create_earley_parser(lexer_conf: LexerConf, parser_conf: ParserConf, options class CYK_FrontEnd: def __init__(self, lexer_conf, parser_conf, options=None): - # self._analysis = GrammarAnalyzer(parser_conf) self.parser = cyk.Parser(parser_conf.rules) self.callbacks = parser_conf.callbacks diff --git a/lark/parsers/grammar_analysis.py b/lark/parsers/grammar_analysis.py index 5a98e4ebf..b52e50d5f 100644 --- a/lark/parsers/grammar_analysis.py +++ b/lark/parsers/grammar_analysis.py @@ -1,16 +1,20 @@ "Provides for superficial grammar analysis." from collections import Counter, defaultdict +from typing import List, Dict, Iterator, FrozenSet, Set from ..utils import bfs, fzset, classify from ..exceptions import GrammarError -from ..grammar import Rule, Terminal, NonTerminal +from ..grammar import Rule, Terminal, NonTerminal, Symbol +from ..common import ParserConf class RulePtr: __slots__ = ('rule', 'index') + rule: Rule + index: int - def __init__(self, rule, index): + def __init__(self, rule: Rule, index: int): assert isinstance(rule, Rule) assert index <= len(rule.expansion) self.rule = rule @@ -22,27 +26,37 @@ def __repr__(self): return '<%s : %s * %s>' % (self.rule.origin.name, ' '.join(before), ' '.join(after)) @property - def next(self): + def next(self) -> Symbol: return self.rule.expansion[self.index] - def advance(self, sym): + def advance(self, sym: Symbol) -> 'RulePtr': assert self.next == sym return RulePtr(self.rule, self.index+1) @property - def is_satisfied(self): + def is_satisfied(self) -> bool: return self.index == len(self.rule.expansion) - def __eq__(self, other): + def __eq__(self, other) -> bool: + if not isinstance(other, RulePtr): + return NotImplemented return self.rule == other.rule and self.index == other.index - def __hash__(self): + + def __hash__(self) -> int: return hash((self.rule, self.index)) +State = FrozenSet[RulePtr] + # state generation ensures no duplicate LR0ItemSets class LR0ItemSet: __slots__ = ('kernel', 'closure', 'transitions', 'lookaheads') + kernel: State + closure: State + transitions: Dict[Symbol, 'LR0ItemSet'] + lookaheads: Dict[Symbol, Set[Rule]] + def __init__(self, kernel, closure): self.kernel = fzset(kernel) self.closure = fzset(closure) @@ -124,7 +138,7 @@ def calculate_sets(rules): class GrammarAnalyzer: - def __init__(self, parser_conf, debug=False, strict=False): + def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False): self.debug = debug self.strict = strict @@ -132,7 +146,7 @@ def __init__(self, 
parser_conf, debug=False, strict=False): for start in parser_conf.start} rules = parser_conf.rules + list(root_rules.values()) - self.rules_by_origin = classify(rules, lambda r: r.origin) + self.rules_by_origin: Dict[NonTerminal, List[Rule]] = classify(rules, lambda r: r.origin) if len(rules) != len(set(rules)): duplicates = [item for item, count in Counter(rules).items() if count > 1] @@ -163,14 +177,14 @@ def __init__(self, parser_conf, debug=False, strict=False): self.FIRST, self.FOLLOW, self.NULLABLE = calculate_sets(rules) - def expand_rule(self, source_rule, rules_by_origin=None): + def expand_rule(self, source_rule: NonTerminal, rules_by_origin=None) -> State: "Returns all init_ptrs accessible by rule (recursive)" if rules_by_origin is None: rules_by_origin = self.rules_by_origin init_ptrs = set() - def _expand_rule(rule): + def _expand_rule(rule: NonTerminal) -> Iterator[NonTerminal]: assert not rule.is_term, rule for r in rules_by_origin[rule]: @@ -180,6 +194,7 @@ def _expand_rule(rule): if r.expansion: # if not empty rule new_r = init_ptr.next if not new_r.is_term: + assert isinstance(new_r, NonTerminal) yield new_r for _ in bfs([source_rule], _expand_rule): diff --git a/lark/parsers/lalr_analysis.py b/lark/parsers/lalr_analysis.py index 7373ebd36..b7b3fdfc7 100644 --- a/lark/parsers/lalr_analysis.py +++ b/lark/parsers/lalr_analysis.py @@ -6,13 +6,15 @@ # Author: Erez Shinan (2017) # Email : erezshin@gmail.com +from typing import Dict, Set, Iterator, Tuple, List, TypeVar, Generic from collections import defaultdict from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger from ..exceptions import GrammarError -from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet -from ..grammar import Rule +from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet, RulePtr, State +from ..grammar import Rule, Symbol +from ..common import ParserConf ###{standalone @@ -27,8 +29,13 @@ def __repr__(self): Shift = Action('Shift') Reduce = Action('Reduce') +StateT = TypeVar("StateT") + +class ParseTableBase(Generic[StateT]): + states: Dict[StateT, Dict[str, Tuple]] + start_states: Dict[str, StateT] + end_states: Dict[str, StateT] -class ParseTable: def __init__(self, states, start_states, end_states): self.states = states self.start_states = start_states @@ -60,13 +67,21 @@ def deserialize(cls, data, memo): } return cls(states, data['start_states'], data['end_states']) +class ParseTable(ParseTableBase['State']): + """Parse-table whose key is State, i.e. set[RulePtr] + + Slower than IntParseTable, but useful for debugging + """ + pass + -class IntParseTable(ParseTable): +class IntParseTable(ParseTableBase[int]): + """Parse-table whose key is int. 
Best for performance.""" @classmethod - def from_ParseTable(cls, parse_table): + def from_ParseTable(cls, parse_table: ParseTable): enum = list(parse_table.states) - state_to_idx = {s:i for i,s in enumerate(enum)} + state_to_idx: Dict['State', int] = {s:i for i,s in enumerate(enum)} int_states = {} for s, la in parse_table.states.items(): @@ -131,7 +146,15 @@ def traverse(x, S, N, X, R, G, F): class LALR_Analyzer(GrammarAnalyzer): - def __init__(self, parser_conf, debug=False, strict=False): + lr0_itemsets: Set[LR0ItemSet] + nonterminal_transitions: List[Tuple[LR0ItemSet, Symbol]] + lookback: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Rule]]] + includes: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]] + reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]] + directly_reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Symbol]] + + + def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False): GrammarAnalyzer.__init__(self, parser_conf, debug, strict) self.nonterminal_transitions = [] self.directly_reads = defaultdict(set) @@ -140,12 +163,12 @@ def __init__(self, parser_conf, debug=False, strict=False): self.lookback = defaultdict(set) - def compute_lr0_states(self): - self.lr0_states = set() + def compute_lr0_states(self) -> None: + self.lr0_itemsets = set() # map of kernels to LR0ItemSets - cache = {} + cache: Dict['State', LR0ItemSet] = {} - def step(state): + def step(state: LR0ItemSet) -> Iterator[LR0ItemSet]: _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied) d = classify(unsat, lambda rp: rp.next) @@ -163,7 +186,7 @@ def step(state): state.transitions[sym] = new_state yield new_state - self.lr0_states.add(state) + self.lr0_itemsets.add(state) for _ in bfs(self.lr0_start_states.values(), step): pass @@ -176,7 +199,7 @@ def compute_reads_relations(self): assert(rp.index == 0) self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ]) - for state in self.lr0_states: + for state in self.lr0_itemsets: seen = set() for rp in state.closure: if rp.is_satisfied: @@ -241,21 +264,22 @@ def compute_lookaheads(self): for s in follow_sets[nt]: state.lookaheads[s].add(rule) - def compute_lalr1_states(self): - m = {} + def compute_lalr1_states(self) -> None: + m: Dict[LR0ItemSet, Dict[str, Tuple]] = {} reduce_reduce = [] - for state in self.lr0_states: - actions = {la: (Shift, next_state.closure) for la, next_state in state.transitions.items()} - for la, rules in state.lookaheads.items(): + for itemset in self.lr0_itemsets: + actions: Dict[Symbol, Tuple] = {la: (Shift, next_state.closure) + for la, next_state in itemset.transitions.items()} + for la, rules in itemset.lookaheads.items(): if len(rules) > 1: # Try to resolve conflict based on priority p = [(r.options.priority or 0, r) for r in rules] p.sort(key=lambda r: r[0], reverse=True) best, second_best = p[:2] if best[0] > second_best[0]: - rules = [best[1]] + rules = {best[1]} else: - reduce_reduce.append((state, la, rules)) + reduce_reduce.append((itemset, la, rules)) continue rule ,= rules @@ -269,30 +293,31 @@ def compute_lalr1_states(self): logger.debug('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name) logger.debug(' * %s', rule) else: - actions[la] = (Reduce, list(rules)[0]) - m[state] = { k.name: v for k, v in actions.items() } + actions[la] = (Reduce, rule) + m[itemset] = { k.name: v for k, v in actions.items() } if reduce_reduce: msgs = [] - for state, la, rules in reduce_reduce: + for itemset, la, rules in reduce_reduce: msg = 
'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ])) if self.debug: - msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in state.closure]) + msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in itemset.closure]) msgs.append(msg) raise GrammarError('\n\n'.join(msgs)) states = { k.closure: v for k, v in m.items() } # compute end states - end_states = {} + end_states: Dict[str, 'State'] = {} for state in states: for rp in state: for start in self.lr0_start_states: if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied: - assert(start not in end_states) + assert start not in end_states end_states[start] = state - _parse_table = ParseTable(states, { start: state.closure for start, state in self.lr0_start_states.items() }, end_states) + start_states = { start: state.closure for start, state in self.lr0_start_states.items() } + _parse_table = ParseTable(states, start_states, end_states) if self.debug: self.parse_table = _parse_table diff --git a/lark/parsers/lalr_parser.py b/lark/parsers/lalr_parser.py index 5a37dc546..6ae2a04fd 100644 --- a/lark/parsers/lalr_parser.py +++ b/lark/parsers/lalr_parser.py @@ -2,19 +2,20 @@ """ # Author: Erez Shinan (2017) # Email : erezshin@gmail.com -from copy import deepcopy, copy -from typing import Dict, Any -from ..lexer import Token +from typing import Dict, Any, Optional +from ..lexer import Token, LexerThread from ..utils import Serialize +from ..common import ParserConf, ParserCallbacks -from .lalr_analysis import LALR_Analyzer, Shift, IntParseTable +from .lalr_analysis import LALR_Analyzer, IntParseTable, ParseTableBase from .lalr_interactive_parser import InteractiveParser from lark.exceptions import UnexpectedCharacters, UnexpectedInput, UnexpectedToken +from .lalr_parser_state import ParserState, ParseConf ###{standalone class LALR_Parser(Serialize): - def __init__(self, parser_conf, debug=False, strict=False): + def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False): analysis = LALR_Analyzer(parser_conf, debug=debug, strict=strict) analysis.compute_lalr() callbacks = parser_conf.callbacks @@ -33,7 +34,7 @@ def deserialize(cls, data, memo, callbacks, debug=False): def serialize(self, memo: Any = None) -> Dict[str, Any]: return self._parse_table.serialize(memo) - def parse_interactive(self, lexer, start): + def parse_interactive(self, lexer: LexerThread, start: str): return self.parser.parse(lexer, start, start_interactive=True) def parse(self, lexer, start, on_error=None): @@ -69,101 +70,17 @@ def parse(self, lexer, start, on_error=None): e = e2 -class ParseConf: - __slots__ = 'parse_table', 'callbacks', 'start', 'start_state', 'end_state', 'states' - - def __init__(self, parse_table, callbacks, start): - self.parse_table = parse_table - - self.start_state = self.parse_table.start_states[start] - self.end_state = self.parse_table.end_states[start] - self.states = self.parse_table.states - - self.callbacks = callbacks - self.start = start - - -class ParserState: - __slots__ = 'parse_conf', 'lexer', 'state_stack', 'value_stack' - - def __init__(self, parse_conf, lexer, state_stack=None, value_stack=None): - self.parse_conf = parse_conf - self.lexer = lexer - self.state_stack = state_stack or [self.parse_conf.start_state] - self.value_stack = value_stack or [] - - @property - def position(self): - return self.state_stack[-1] - - # Necessary for match_examples() to work - def __eq__(self, 
other): - if not isinstance(other, ParserState): - return NotImplemented - return len(self.state_stack) == len(other.state_stack) and self.position == other.position - - def __copy__(self): - return type(self)( - self.parse_conf, - self.lexer, # XXX copy - copy(self.state_stack), - deepcopy(self.value_stack), - ) - - def copy(self): - return copy(self) - - def feed_token(self, token, is_end=False): - state_stack = self.state_stack - value_stack = self.value_stack - states = self.parse_conf.states - end_state = self.parse_conf.end_state - callbacks = self.parse_conf.callbacks - - while True: - state = state_stack[-1] - try: - action, arg = states[state][token.type] - except KeyError: - expected = {s for s in states[state].keys() if s.isupper()} - raise UnexpectedToken(token, expected, state=self, interactive_parser=None) - - assert arg != end_state - - if action is Shift: - # shift once and return - assert not is_end - state_stack.append(arg) - value_stack.append(token if token.type not in callbacks else callbacks[token.type](token)) - return - else: - # reduce+shift as many times as necessary - rule = arg - size = len(rule.expansion) - if size: - s = value_stack[-size:] - del state_stack[-size:] - del value_stack[-size:] - else: - s = [] - - value = callbacks[rule](s) if callbacks else s - - _action, new_state = states[state_stack[-1]][rule.origin.name] - assert _action is Shift - state_stack.append(new_state) - value_stack.append(value) - - if is_end and state_stack[-1] == end_state: - return value_stack[-1] - class _Parser: - def __init__(self, parse_table, callbacks, debug=False): + parse_table: ParseTableBase + callbacks: ParserCallbacks + debug: bool + + def __init__(self, parse_table: ParseTableBase, callbacks: ParserCallbacks, debug: bool=False): self.parse_table = parse_table self.callbacks = callbacks self.debug = debug - def parse(self, lexer, start, value_stack=None, state_stack=None, start_interactive=False): + def parse(self, lexer: LexerThread, start: str, value_stack=None, state_stack=None, start_interactive=False): parse_conf = ParseConf(self.parse_table, self.callbacks, start) parser_state = ParserState(parse_conf, lexer, state_stack, value_stack) if start_interactive: @@ -171,16 +88,17 @@ def parse(self, lexer, start, value_stack=None, state_stack=None, start_interact return self.parse_from_state(parser_state) - def parse_from_state(self, state, last_token=None): + def parse_from_state(self, state: ParserState, last_token: Optional[Token]=None): """Run the main LALR parser loop Parameters: - state (ParseState) - the initial state. Changed in-place. - last_token (optional, Token) - Used only for line information in case of an empty lexer. + state - the initial state. Changed in-place. + last_token - Used only for line information in case of an empty lexer. 
""" try: token = last_token for token in state.lexer.lex(state): + assert token is not None state.feed_token(token) end_token = Token.new_borrow_pos('$END', '', token) if token else Token('$END', '', 0, 1, 1) diff --git a/lark/parsers/lalr_parser_state.py b/lark/parsers/lalr_parser_state.py new file mode 100644 index 000000000..350569769 --- /dev/null +++ b/lark/parsers/lalr_parser_state.py @@ -0,0 +1,110 @@ +from copy import deepcopy, copy +from typing import Dict, Any, Generic, List +from ..lexer import Token, LexerThread +from ..common import ParserCallbacks + +from .lalr_analysis import Shift, ParseTableBase, StateT +from lark.exceptions import UnexpectedToken + +###{standalone + +class ParseConf(Generic[StateT]): + __slots__ = 'parse_table', 'callbacks', 'start', 'start_state', 'end_state', 'states' + + parse_table: ParseTableBase[StateT] + callbacks: ParserCallbacks + start: str + + start_state: StateT + end_state: StateT + states: Dict[StateT, Dict[str, tuple]] + + def __init__(self, parse_table: ParseTableBase[StateT], callbacks: ParserCallbacks, start: str): + self.parse_table = parse_table + + self.start_state = self.parse_table.start_states[start] + self.end_state = self.parse_table.end_states[start] + self.states = self.parse_table.states + + self.callbacks = callbacks + self.start = start + +class ParserState(Generic[StateT]): + __slots__ = 'parse_conf', 'lexer', 'state_stack', 'value_stack' + + parse_conf: ParseConf[StateT] + lexer: LexerThread + state_stack: List[StateT] + value_stack: list + + def __init__(self, parse_conf: ParseConf[StateT], lexer: LexerThread, state_stack=None, value_stack=None): + self.parse_conf = parse_conf + self.lexer = lexer + self.state_stack = state_stack or [self.parse_conf.start_state] + self.value_stack = value_stack or [] + + @property + def position(self) -> StateT: + return self.state_stack[-1] + + # Necessary for match_examples() to work + def __eq__(self, other) -> bool: + if not isinstance(other, ParserState): + return NotImplemented + return len(self.state_stack) == len(other.state_stack) and self.position == other.position + + def __copy__(self): + return type(self)( + self.parse_conf, + self.lexer, # XXX copy + copy(self.state_stack), + deepcopy(self.value_stack), + ) + + def copy(self) -> 'ParserState[StateT]': + return copy(self) + + def feed_token(self, token: Token, is_end=False) -> Any: + state_stack = self.state_stack + value_stack = self.value_stack + states = self.parse_conf.states + end_state = self.parse_conf.end_state + callbacks = self.parse_conf.callbacks + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken(token, expected, state=self, interactive_parser=None) + + assert arg != end_state + + if action is Shift: + # shift once and return + assert not is_end + state_stack.append(arg) + value_stack.append(token if token.type not in callbacks else callbacks[token.type](token)) + return + else: + # reduce+shift as many times as necessary + rule = arg + size = len(rule.expansion) + if size: + s = value_stack[-size:] + del state_stack[-size:] + del value_stack[-size:] + else: + s = [] + + value = callbacks[rule](s) if callbacks else s + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action is Shift + state_stack.append(new_state) + value_stack.append(value) + + if is_end and state_stack[-1] == end_state: + return value_stack[-1] +###} diff --git 
a/lark/tools/standalone.py b/lark/tools/standalone.py index 3ae2cdb6d..1901f71b0 100644 --- a/lark/tools/standalone.py +++ b/lark/tools/standalone.py @@ -25,11 +25,10 @@ # from abc import ABC, abstractmethod -from collections.abc import Sequence from types import ModuleType from typing import ( TypeVar, Generic, Type, Tuple, List, Dict, Iterator, Collection, Callable, Optional, FrozenSet, Any, - Union, Iterable, IO, TYPE_CHECKING, overload, + Union, Iterable, IO, TYPE_CHECKING, overload, Sequence, Pattern as REPattern, ClassVar, Set, Mapping ) ###} @@ -63,8 +62,9 @@ 'lexer.py', 'common.py', 'parse_tree_builder.py', - 'parsers/lalr_parser.py', 'parsers/lalr_analysis.py', + 'parsers/lalr_parser_state.py', + 'parsers/lalr_parser.py', 'parser_frontends.py', 'lark.py', 'indenter.py', diff --git a/lark/utils.py b/lark/utils.py index 20cd6e298..97db71199 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -265,13 +265,13 @@ def __repr__(self): return '{%s}' % ', '.join(map(repr, self)) -def classify_bool(seq: Sequence, pred: Callable) -> Any: +def classify_bool(seq: Iterable, pred: Callable) -> Any: false_elems = [] true_elems = [elem for elem in seq if pred(elem) or false_elems.append(elem)] # type: ignore[func-returns-value] return true_elems, false_elems -def bfs(initial: Sequence, expand: Callable) -> Iterator: +def bfs(initial: Iterable, expand: Callable) -> Iterator: open_q = deque(list(initial)) visited = set(open_q) while open_q:
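
---

Reviewer note, not part of the patch: since this change is almost entirely type plumbing (the `ParserCallbacks` alias, the generic `ParseTableBase[StateT]`, `ParserState[StateT]` in the new `lark/parsers/lalr_parser_state.py`, and `interactive_parser` moving up to `UnexpectedInput`), a short sketch of where those annotations surface at the public API may help while reviewing. The toy grammar and variable names below are illustrative assumptions, not code from this PR.

```python
# Minimal sketch using the public Lark API; grammar and names are assumptions.
from lark import Lark
from lark.exceptions import UnexpectedInput

parser = Lark(r"""
    start: "a" "b"
    %ignore " "
""", parser="lalr")

# parse_interactive() hands back an InteractiveParser whose .parser_state is
# the ParserState[StateT] defined in lark/parsers/lalr_parser_state.py
# (StateT = int at runtime, since the compiled table is an IntParseTable).
ip = parser.parse_interactive("a b")
ip.exhaust_lexer()      # drives the remaining tokens through feed_token()
tree = ip.feed_eof()    # reaches the end state and returns the parse result
print(tree)

# With `interactive_parser` now annotated on UnexpectedInput itself,
# error-recovery code can read it without first narrowing to UnexpectedToken.
try:
    parser.parse("a a")
except UnexpectedInput as e:
    puppet = e.interactive_parser   # set by the LALR error path
    print(puppet.accepts())         # token types the parser would accept here
```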