Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix #708 #1025

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
72 changes: 72 additions & 0 deletions njit_fastmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pathlib

from utils import check_callees, check_functions

stumpy_path = pathlib.Path(__file__).parent / "stumpy"
filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())

all_functions = {}

ignore = ["__init__.py", "__pycache__"]
for filepath in filepaths:
file_name = filepath.name
if file_name not in ignore and str(filepath).endswith(".py"):
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved
prefix = file_name.replace(".py", "")
NimaSarajpoor marked this conversation as resolved.
Show resolved Hide resolved

func_names, is_njit, fastmath_values = check_functions(filepath)
func_names = [f"{prefix}.{fn}" for fn in func_names]

all_functions[file_name] = {
"func_names": func_names,
"is_njit": is_njit,
"fastmath_values": fastmath_values,
}

all_stumpy_functions = set()
for file_name, file_functions_metadata in all_functions.items():
all_stumpy_functions.update(file_functions_metadata["func_names"])

all_stumpy_functions = list(all_stumpy_functions)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels strange. You declare all_stumpy_functions as a set() above and then you update it and then you convert it into a list all within a for-loop.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that I forgot to notice is that the nested dictionary all_functions has unique values for all_functions['file_name']['func_names"]. So, I can just use list.

all_stumpy_functions = []
for file_name in all_functions.keys():
    all_stumpy_functions.extend(all_functions['file_name']['func_names'])

Does that address your concern?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not entirely clear what the goal is here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every time I look at it... I feel like the nested dictionary makes it hard to read. I need to re-design it.

It's not entirely clear what the goal is here

My goal was to create a comprehensive list that contains all the functions that are defined in ./stumpy/, and then use it to ignore any callee that is NOT in that comprehensive list. For instance, range can be a callee of a function. However, I do not want to keep it. I want to only have the callers/ callees that are defined in stumpy.

all_stumpy_functions_no_prefix = [f.split(".")[-1] for f in all_stumpy_functions]


# output 1: func_metadata
func_metadata = {}
for file_name, file_functions_metadata in all_functions.items():
for i, f in enumerate(file_functions_metadata["func_names"]):
is_njit = file_functions_metadata["is_njit"][i]
fastmath_value = file_functions_metadata["fastmath_values"][i]
func_metadata[f] = [is_njit, fastmath_value]


# output 2: func_callers
func_callers = {}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanlaw
Previously, I wanted to follow the chain of callees. Now, as you may notice here, I am trying to get list of callers for each function. We can start with those functions that are njit-decorated and have fastmath flag, and then follow the chain of callers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Whatever you think might work

for f in func_metadata.keys():
func_callers[f] = []

for filepath in filepaths:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going to confess here... I think there is a code smell here because of nested for-loop. Although I tried a couple of times, I think I still need to work on this part as it seems to be complicated.

file_name = filepath.name
if file_name in ignore or not str(filepath).endswith(".py"):
continue

prefix = file_name.replace(".py", "")
callees = check_callees(filepath)

current_callers = set(callees.keys())
for caller, callee_set in callees.items():
s = list(callee_set.intersection(all_stumpy_functions_no_prefix))
if len(s) == 0:
continue

for c in s:
if c in current_callers:
c_name = prefix + "." + c
else:
idx = all_stumpy_functions_no_prefix.index(c)
c_name = all_stumpy_functions[idx]

func_callers[c_name].append(f"{prefix}.{caller}")


for f, callers in func_callers.items():
func_callers[f] = list(set(callers))
146 changes: 146 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import ast


def check_fastmath(decorator):
"""
For the given `decorator` node with type `ast.Call`,
return the value of the `fastmath` argument if it exists.
Otherwise, return `None`.
"""
fastmath_value = None
for n in ast.iter_child_nodes(decorator):
if isinstance(n, ast.keyword) and n.arg == "fastmath":
if isinstance(n.value, ast.Constant):
fastmath_value = n.value.value
elif isinstance(n.value, ast.Set):
fastmath_value = set(item.value for item in n.value.elts)
else:
pass
break

return fastmath_value


def check_njit(fd):
"""
For the given `fd` node with type `ast.FunctionDef`,
return the node of the `njit` decorator if it exists.
Otherwise, return `None`.
"""
decorator_node = None
for decorator in fd.decorator_list:
if not isinstance(decorator, ast.Call):
continue

obj = decorator.func
if isinstance(obj, ast.Attribute):
name = obj.attr
elif isinstance(obj, ast.Subscript):
name = obj.value.id
elif isinstance(obj, ast.Name):
name = obj.id
else:
msg = f"The type {type(obj)} is not supported."
raise ValueError(msg)

if name == "njit":
decorator_node = decorator
break

return decorator_node


def check_functions(filepath):
"""
For the given `filepath`, return the function names,
whether the function is decorated with `@njit`,
and the value of the `fastmath` argument if it exists

Parameters
----------
filepath : str
The path to the file

Returns
-------
func_names : list
List of function names

is_njit : list
List of boolean values indicating whether the function is decorated with `@njit`

fastmath_value : list
List of values of the `fastmath` argument if it exists
"""
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
module = ast.parse(file_contents)

function_definitions = [
node for node in module.body if isinstance(node, ast.FunctionDef)
]

func_names = [fd.name for fd in function_definitions]

njit_nodes = [check_njit(fd) for fd in function_definitions]
is_njit = [node is not None for node in njit_nodes]

fastmath_values = [None] * len(njit_nodes)
for i, node in enumerate(njit_nodes):
if node is not None:
fastmath_values[i] = check_fastmath(node)

return func_names, is_njit, fastmath_values


def _get_callees(node, all_functions):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Note to self]
One thing I need to consider here is how things will work if a callee is assigned to a variable before being used. For instance:

a_subseq_isconstant = _rolling_isconstant

for n in ast.iter_child_nodes(node):
if isinstance(n, ast.Call):
obj = n.func
if isinstance(obj, ast.Attribute):
name = obj.attr
elif isinstance(obj, ast.Subscript):
name = obj.value.id
elif isinstance(obj, ast.Name):
name = obj.id
else:
msg = f"The type {type(obj)} is not supported"
raise ValueError(msg)

all_functions.append(name)

_get_callees(n, all_functions)


def get_all_callees(fd):
"""
For a given node of type ast.FunctionDef, visit all of its child nodes,
and return a list of all of its callees
"""
all_functions = []
_get_callees(fd, all_functions)

return all_functions


def check_callees(filepath):
"""
For the given `filepath`, return a dictionary with the key
being the function name and the value being a set of function names
that are called by the function
"""
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
module = ast.parse(file_contents)

function_definitions = [
node for node in module.body if isinstance(node, ast.FunctionDef)
]

callees = {}
for fd in function_definitions:
callees[fd.name] = set(get_all_callees(fd))

return callees
Loading