Skip to content

Commit

Permalink
Add types to jinja-related files
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Sep 7, 2024
1 parent ce09ad3 commit eb57bbe
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 53 deletions.
124 changes: 75 additions & 49 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@
from collections import ChainMap
from contextlib import contextmanager
from itertools import chain, islice
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from types import CodeType
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Union,
Set,
Type,
NoReturn,
)

from typing_extensions import Protocol

import jinja2
Expand Down Expand Up @@ -39,10 +53,17 @@
SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param")

# Global which can be set by dependents of dbt-common (e.g. core via flag parsing)
MACRO_DEBUGGING = False
MACRO_DEBUGGING: Union[str, bool] = False

_ParseReturn = Union[jinja2.nodes.Node, List[jinja2.nodes.Node]]


# Temporary type capturing the concept the functions in this file expect for a "node"
class _NodeProtocol(Protocol):
pass


def _linecache_inject(source, write):
def _linecache_inject(source: str, write: bool) -> str:
if write:
# this is the only reliable way to accomplish this. Obviously, it's
# really darn noisy and will fill your temporary directory
Expand All @@ -58,18 +79,18 @@ def _linecache_inject(source, write):
else:
# `codecs.encode` actually takes a `bytes` as the first argument if
# the second argument is 'hex' - mypy does not know this.
rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
rnd = codecs.encode(os.urandom(12), "hex")
filename = rnd.decode("ascii")

# put ourselves in the cache
cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename)
# linecache does in fact have an attribute `cache`, thanks
linecache.cache[filename] = cache_entry # type: ignore
linecache.cache[filename] = cache_entry
return filename


class MacroFuzzParser(jinja2.parser.Parser):
def parse_macro(self):
def parse_macro(self) -> jinja2.nodes.Macro:
node = jinja2.nodes.Macro(lineno=next(self.stream).lineno)

# modified to fuzz macros defined in the same file. this way
Expand All @@ -83,16 +104,13 @@ def parse_macro(self):


class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
def _parse(self, source, name, filename):
def _parse(
self, source: str, name: Optional[str], filename: Optional[str]
) -> jinja2.nodes.Template:
return MacroFuzzParser(self, source, name, filename).parse()

def _compile(self, source, filename):
def _compile(self, source: str, filename: str) -> CodeType:
"""
Override jinja's compilation. Use to stash the rendered source inside
the python linecache for debugging when the appropriate environment
variable is set.
Expand All @@ -108,7 +126,7 @@ def _compile(self, source, filename):


class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
environment_class = MacroFuzzEnvironment
environment_class = MacroFuzzEnvironment # type: ignore

def new_context(
self,
Expand Down Expand Up @@ -171,11 +189,11 @@ class NumberMarker(NativeMarker):
pass


def _is_number(value) -> bool:
def _is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)


def quoted_native_concat(nodes):
def quoted_native_concat(nodes: Iterator[str]) -> Any:
"""Handle special case for native_concat from the NativeTemplate.
This is almost native_concat from the NativeTemplate, except in the
Expand Down Expand Up @@ -213,7 +231,7 @@ def quoted_native_concat(nodes):
class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
environment_class = NativeSandboxEnvironment # type: ignore

def render(self, *args, **kwargs):
def render(self, *args: Any, **kwargs: Any) -> Any:
"""Render the template to produce a native Python type.
If the result is a single node, its value is returned. Otherwise,
Expand All @@ -229,14 +247,19 @@ def render(self, *args, **kwargs):
return self.environment.handle_exception()


class MacroProtocol(Protocol):
name: str
macro_sql: str


NativeSandboxEnvironment.template_class = NativeSandboxTemplate # type: ignore


class TemplateCache:
def __init__(self) -> None:
self.file_cache: Dict[str, jinja2.Template] = {}

def get_node_template(self, node) -> jinja2.Template:
def get_node_template(self, node: MacroProtocol) -> jinja2.Template:
key = node.macro_sql

if key in self.file_cache:
Expand All @@ -251,7 +274,7 @@ def get_node_template(self, node) -> jinja2.Template:
self.file_cache[key] = template
return template

def clear(self):
def clear(self) -> None:
self.file_cache.clear()


Expand All @@ -262,13 +285,13 @@ class BaseMacroGenerator:
def __init__(self, context: Optional[Dict[str, Any]] = None) -> None:
self.context: Optional[Dict[str, Any]] = context

def get_template(self):
def get_template(self) -> jinja2.Template:
raise NotImplementedError("get_template not implemented!")

def get_name(self) -> str:
raise NotImplementedError("get_name not implemented!")

def get_macro(self):
def get_macro(self) -> Callable:
name = self.get_name()
template = self.get_template()
# make the module. previously we set both vars and local, but that's
Expand All @@ -286,7 +309,7 @@ def exception_handler(self) -> Iterator[None]:
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
raise CaughtMacroError(e)

def call_macro(self, *args, **kwargs):
def call_macro(self, *args: Any, **kwargs: Any) -> Any:
# called from __call__ methods
if self.context is None:
raise DbtInternalError("Context is still None in call_macro!")
Expand All @@ -301,11 +324,6 @@ def call_macro(self, *args, **kwargs):
return e.value


class MacroProtocol(Protocol):
name: str
macro_sql: str


class CallableMacroGenerator(BaseMacroGenerator):
def __init__(
self,
Expand All @@ -315,7 +333,7 @@ def __init__(
super().__init__(context)
self.macro = macro

def get_template(self):
def get_template(self) -> jinja2.Template:
return template_cache.get_node_template(self.macro)

def get_name(self) -> str:
Expand All @@ -332,14 +350,14 @@ def exception_handler(self) -> Iterator[None]:
raise e

# this makes MacroGenerator objects callable like functions
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.call_macro(*args, **kwargs)


class MaterializationExtension(jinja2.ext.Extension):
tags = ["materialization"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
materialization_name = parser.parse_assign_target(name_only=True).name

Expand Down Expand Up @@ -382,7 +400,7 @@ def parse(self, parser):
class DocumentationExtension(jinja2.ext.Extension):
tags = ["docs"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
docs_name = parser.parse_assign_target(name_only=True).name

Expand All @@ -396,7 +414,7 @@ def parse(self, parser):
class TestExtension(jinja2.ext.Extension):
tags = ["test"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
test_name = parser.parse_assign_target(name_only=True).name

Expand All @@ -406,13 +424,19 @@ def parse(self, parser):
return node


def _is_dunder_name(name):
def _is_dunder_name(name: str) -> bool:
return name.startswith("__") and name.endswith("__")


def create_undefined(node=None):
def create_undefined(node: Optional[_NodeProtocol] = None) -> Type[jinja2.Undefined]:
class Undefined(jinja2.Undefined):
def __init__(self, hint=None, obj=None, name=None, exc=None):
def __init__(
self,
hint: Optional[str] = None,
obj: Any = None,
name: Optional[str] = None,
exc: Any = None,
) -> None:
super().__init__(hint=hint, name=name)
self.node = node
self.name = name
Expand All @@ -422,12 +446,12 @@ def __init__(self, hint=None, obj=None, name=None, exc=None):
self.unsafe_callable = False
self.alters_data = False

def __getitem__(self, name):
def __getitem__(self, name: Any) -> "Undefined":
# Propagate the undefined value if a caller accesses this as if it
# were a dictionary
return self

def __getattr__(self, name):
def __getattr__(self, name: str) -> "Undefined":
if name == "name" or _is_dunder_name(name):
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
Expand All @@ -437,11 +461,11 @@ def __getattr__(self, name):

return self.__class__(hint=self.hint, name=self.name)

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> "Undefined":
return self

def __reduce__(self):
raise UndefinedCompilationError(name=self.name, node=node)
def __reduce__(self) -> NoReturn:
raise UndefinedCompilationError(name=self.name or "unknown", node=node)

return Undefined

Expand All @@ -463,7 +487,7 @@ def __reduce__(self):


def get_environment(
node=None,
node: Optional[_NodeProtocol] = None,
capture_macros: bool = False,
native: bool = False,
) -> jinja2.Environment:
Expand All @@ -472,7 +496,7 @@ def get_environment(
}

if capture_macros:
args["undefined"] = create_undefined(node)
args["undefined"] = create_undefined(node) # type: ignore

args["extensions"].append(MaterializationExtension)
args["extensions"].append(DocumentationExtension)
Expand All @@ -493,7 +517,7 @@ def get_environment(


@contextmanager
def catch_jinja(node=None) -> Iterator[None]:
def catch_jinja(node: Optional[_NodeProtocol] = None) -> Iterator[None]:
try:
yield
except jinja2.exceptions.TemplateSyntaxError as e:
Expand All @@ -506,16 +530,16 @@ def catch_jinja(node=None) -> Iterator[None]:
raise


_TESTING_PARSE_CACHE: Dict[str, jinja2.Template] = {}
_TESTING_PARSE_CACHE: Dict[str, jinja2.nodes.Template] = {}


def parse(string):
def parse(string: Any) -> jinja2.nodes.Template:
str_string = str(string)
if test_caching_enabled() and str_string in _TESTING_PARSE_CACHE:
return _TESTING_PARSE_CACHE[str_string]

with catch_jinja():
parsed = get_environment().parse(str(string))
parsed: jinja2.nodes.Template = get_environment().parse(str(string))
if test_caching_enabled():
_TESTING_PARSE_CACHE[str_string] = parsed
return parsed
Expand All @@ -524,18 +548,20 @@ def parse(string):
def get_template(
string: str,
ctx: Dict[str, Any],
node=None,
node: Optional[_NodeProtocol] = None,
capture_macros: bool = False,
native: bool = False,
):
) -> jinja2.Template:
with catch_jinja(node):
env = get_environment(node, capture_macros, native=native)

template_source = str(string)
return env.from_string(template_source, globals=ctx)


def render_template(template, ctx: Dict[str, Any], node=None) -> str:
def render_template(
template: jinja2.Template, ctx: Dict[str, Any], node: Optional[_NodeProtocol] = None
) -> str:
with catch_jinja(node):
return template.render(ctx)

Expand Down
9 changes: 5 additions & 4 deletions dbt_common/exceptions/jinja.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dbt_common.clients._jinja_blocks import Tag, TagIterator
from dbt_common.exceptions import CompilationError


class BlockDefinitionNotAtTopError(CompilationError):
def __init__(self, tag_parser, tag_start) -> None:
def __init__(self, tag_parser: TagIterator, tag_start: int) -> None:
self.tag_parser = tag_parser
self.tag_start = tag_start
super().__init__(msg=self.get_message())
Expand Down Expand Up @@ -31,7 +32,7 @@ def get_message(self) -> str:


class MissingControlFlowStartTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
Expand All @@ -47,7 +48,7 @@ def get_message(self) -> str:


class NestedTagsError(CompilationError):
def __init__(self, outer, inner) -> None:
def __init__(self, outer: Tag, inner: Tag) -> None:
self.outer = outer
self.inner = inner
super().__init__(msg=self.get_message())
Expand All @@ -62,7 +63,7 @@ def get_message(self) -> str:


class UnexpectedControlFlowEndTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
Expand Down

0 comments on commit eb57bbe

Please sign in to comment.