Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor[lang]: remove ASTTokens #4364

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
101 changes: 50 additions & 51 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import ast as python_ast
import tokenize
from decimal import Decimal
from functools import cached_property
from typing import Any, Dict, List, Optional, Union

import asttokens

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import PreParser
from vyper.compiler.settings import Settings
Expand Down Expand Up @@ -138,16 +137,8 @@ def annotate_python_ast(
-------
The annotated and optimized AST.
"""
tokens = asttokens.ASTTokens(vyper_source)
assert isinstance(parsed_ast, python_ast.Module) # help mypy
tokens.mark_tokens(parsed_ast)
visitor = AnnotatingVisitor(
vyper_source,
pre_parser,
tokens,
source_id,
module_path=module_path,
resolved_path=resolved_path,
vyper_source, pre_parser, source_id, module_path=module_path, resolved_path=resolved_path
)
visitor.visit(parsed_ast)

Expand All @@ -162,20 +153,32 @@ def __init__(
self,
source_code: str,
pre_parser: PreParser,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
):
self._tokens = tokens
self._source_id = source_id
self._module_path = module_path
self._resolved_path = resolved_path
self._source_code = source_code
self._parent = None
self._pre_parser = pre_parser

self.counter: int = 0

@cached_property
def source_lines(self):
    """The source text split into individual lines, line endings preserved."""
    # keepends must stay on: line_offsets sums these lengths to compute
    # absolute character offsets, which requires the newline characters.
    return self._source_code.splitlines(True)

@cached_property
def line_offsets(self):
    """Map each 1-based line number to the absolute character offset
    of that line's first character within the source text."""
    offsets = {}
    running = 0
    for idx, text in enumerate(self.source_lines, start=1):
        offsets[idx] = running
        running += len(text)
    return offsets

def generic_visit(self, node):
"""
Annotate a node with information that simplifies Vyper node generation.
Expand All @@ -186,33 +189,34 @@ def generic_visit(self, node):
node.ast_type = node.__class__.__name__
self.counter += 1

# Decorate every node with source end offsets
start = (None, None)
if hasattr(node, "first_token"):
start = node.first_token.start
end = (None, None)
if hasattr(node, "last_token"):
end = node.last_token.end
if node.last_token.type == 4:
# token type 4 is a `\n`, some nodes include a trailing newline
# here we ignore it when building the node offsets
end = (end[0], end[1] - 1)

node.lineno = start[0]
node.col_offset = start[1]
node.end_lineno = end[0]
node.end_col_offset = end[1]

# TODO: adjust end_lineno and end_col_offset when this node is in
# modification_offsets

if hasattr(node, "last_token"):
start_pos = node.first_token.startpos
end_pos = node.last_token.endpos

if node.last_token.type == 4:
# ignore trailing newline once more
end_pos -= 1
if isinstance(node, python_ast.Module):
node.lineno = 1
node.col_offset = 0
node.end_lineno = len(self.source_lines)

if len(self.source_lines) > 0:
node.end_col_offset = len(self.source_lines[-1])
else:
node.end_col_offset = 0

adjustments = self._pre_parse_result.adjustments

for s in ("lineno", "end_lineno", "col_offset", "end_col_offset"):
# ensure fields exist
setattr(node, s, getattr(node, s, None))

if node.col_offset is not None:
adj = adjustments.get((node.lineno, node.col_offset), 0)
node.col_offset += adj

if node.end_col_offset is not None:
adj = adjustments.get((node.end_lineno, node.end_col_offset), 0)
node.end_col_offset += adj

if node.lineno in self.line_offsets and node.end_lineno in self.line_offsets:
start_pos = self.line_offsets[node.lineno] + node.col_offset
end_pos = self.line_offsets[node.end_lineno] + node.end_col_offset

node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
node.node_source_code = self._source_code[start_pos:end_pos]

Expand Down Expand Up @@ -248,12 +252,6 @@ def visit_Module(self, node):
return self._visit_docstring(node)

def visit_FunctionDef(self, node):
if node.decorator_list:
# start the source highlight at `def` to improve annotation readability
decorator_token = node.decorator_list[-1].last_token
def_token = self._tokens.find_token(decorator_token, tokenize.NAME, tok_str="def")
node.first_token = def_token

return self._visit_docstring(node)

def visit_ClassDef(self, node):
Expand All @@ -269,6 +267,12 @@ def visit_ClassDef(self, node):
node.ast_type = self._pre_parser.modification_offsets[(node.lineno, node.col_offset)]
return node

def visit_Load(self, node):
    # NOTE(review): returning None appears to drop expression-context
    # nodes from the annotated tree — confirm against the base visitor's
    # handling of None results.
    return

def visit_Store(self, node):
    # NOTE(review): mirrors visit_Load — returning None appears to drop
    # the ctx node; confirm against the base visitor's handling of None.
    return

def visit_For(self, node):
"""
Visit a For node, splicing in the loop variable annotation provided by
Expand Down Expand Up @@ -314,11 +318,6 @@ def visit_For(self, node):
"invalid type annotation", self._source_code, node.lineno, node.col_offset
) from e

# fill in with asttokens info. note we can use `self._tokens` because
# it is indented to exactly the same position where it appeared
# in the original source!
self._tokens.mark_tokens(fake_node)

# replace the dummy target name with the real target name.
fake_node.target = node.target
# replace the For node target with the new ann_assign
Expand Down
49 changes: 34 additions & 15 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ class PreParser:
settings: Settings
# A mapping of class names to their original class types.
modification_offsets: dict[tuple[int, int], str]

# Magic adjustments
adjustments: dict[tuple[int, int], int]

# A mapping of line/column offsets of `For` nodes to the annotation of the for loop target
for_loop_annotations: dict[tuple[int, int], list[TokenInfo]]
# A list of line/column offsets of hex string literals
Expand Down Expand Up @@ -196,6 +200,7 @@ def parse(self, code: str):
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e

def _parse(self, code: str):
adjustments: dict = {}
result: list[TokenInfo] = []
modification_offsets: dict[tuple[int, int], str] = {}
settings = Settings()
Expand All @@ -216,6 +221,12 @@ def _parse(self, code: str):
end = token.end
line = token.line

for tok in toks:
lineno, col = tok.start
adj = _col_adjustments[lineno]
newstart = lineno, col - adj
adjustments[lineno, col - adj] = adj

if typ == COMMENT:
contents = string[1:].strip()
if contents.startswith("@version"):
Expand Down Expand Up @@ -273,15 +284,25 @@ def _parse(self, code: str):

if typ == NAME:
if string in VYPER_CLASS_TYPES and start[1] == 0:
toks = [TokenInfo(NAME, "class", start, end, line)]
modification_offsets[start] = VYPER_CLASS_TYPES[string]
new_keyword = "class"
toks = [TokenInfo(NAME, new_keyword, start, end, line)]

adjustment = len(string) - len(new_keyword)
# adjustments for following tokens
lineno, col = start
_col_adjustments[lineno] += adjustment

modification_offsets[newstart] = VYPER_CLASS_TYPES[string]

elif string in CUSTOM_STATEMENT_TYPES:
new_keyword = "yield"
adjustment = len(new_keyword) - len(string)
# adjustments for following staticcall/extcall modification_offsets
_col_adjustments[start[0]] += adjustment
adjustment = len(string) - len(new_keyword)
# adjustments for following tokens
lineno, col = start
_col_adjustments[lineno] += adjustment
toks = [TokenInfo(NAME, new_keyword, start, end, line)]
modification_offsets[start] = CUSTOM_STATEMENT_TYPES[string]
modification_offsets[newstart] = CUSTOM_STATEMENT_TYPES[string]

elif string in CUSTOM_EXPRESSION_TYPES:
# a bit cursed technique to get untokenize to put
# the new tokens in the right place so that modification_offsets
Expand All @@ -291,18 +312,15 @@ def _parse(self, code: str):
new_keyword = "await"
vyper_type = CUSTOM_EXPRESSION_TYPES[string]

lineno, col_offset = start
adjustment = len(string) - len(new_keyword)
# adjustments for following tokens
lineno, col = start
_col_adjustments[lineno] += adjustment

# fixup for when `extcall/staticcall` follows `log`
adjustment = _col_adjustments[lineno]
new_start = (lineno, col_offset + adjustment)
modification_offsets[new_start] = vyper_type
modification_offsets[newstart] = vyper_type

# tells untokenize to add whitespace, preserving locations
diff = len(new_keyword) - len(string)
new_end = end[0], end[1] + diff

toks = [TokenInfo(NAME, new_keyword, start, new_end, line)]
toks = [TokenInfo(NAME, new_keyword, start, end, line)]

if (typ, string) == (OP, ";"):
raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])
Expand All @@ -314,6 +332,7 @@ def _parse(self, code: str):
for k, v in for_parser.annotations.items():
for_loop_annotations[k] = v.copy()

self.adjustments = adjustments
self.settings = settings
self.modification_offsets = modification_offsets
self.for_loop_annotations = for_loop_annotations
Expand Down
Loading