Skip to content

Commit

Permalink
Fixed #1059 missing fastmath (#1060)
Browse files Browse the repository at this point in the history
* add function to  check fastmath

* revise fastmath script and add support for reading arg from command line

* minor fixes

* rename function to improve readability

* simplify code by passing boolean value

* fix to catch njit functions with decorator

* use regex to find njit functions

* minor change

* revise code to detect bare njit decorator

* minor fix

* fix path

* add missing fastmath

* revise fastmath flag

* Improve ValueError msg

* fix format

* enable function to accept path as input

* pass param via CLI, and some minor changes

* adapt changes in test script

* use type str for the param pkg_dir

* minor changes

* Revised string concatenation in error message
  • Loading branch information
NimaSarajpoor authored Jan 12, 2025
1 parent ce0cd8c commit 70e4e70
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 12 deletions.
100 changes: 100 additions & 0 deletions fastmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python

import argparse
import ast
import importlib
import pathlib


def get_njit_funcs(pkg_dir):
"""
Identify all njit functions
Parameters
----------
pkg_dir : str
The path to the directory containing some .py files
Returns
-------
njit_funcs : list
A list of all njit functions, where each element is a tuple of the form
(module_name, func_name)
"""
ignore_py_files = ["__init__", "__pycache__"]
pkg_dir = pathlib.Path(pkg_dir)

module_names = []
for fname in pkg_dir.iterdir():
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
module_names.append(fname.stem)

njit_funcs = []
for module_name in module_names:
filepath = pkg_dir / f"{module_name}.py"
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
module = ast.parse(file_contents)
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))

return njit_funcs


def check_fastmath(pkg_dir, pkg_name):
"""
Check if all njit functions have the `fastmath` flag set
Parameters
----------
pkg_dir : str
The path to the directory containing some .py files
pkg_name : str
The name of the package
Returns
-------
None
"""
missing_fastmath = [] # list of njit functions with missing fastmath flags
for module_name, func_name in get_njit_funcs(pkg_dir):
module = importlib.import_module(f".{module_name}", package=pkg_name)
func = getattr(module, func_name)
if "fastmath" not in func.targetoptions.keys():
missing_fastmath.append(f"{module_name}.{func_name}")

if len(missing_fastmath) > 0:
msg = (
"Found one or more `@njit` functions that are missing the `fastmath` flag. "
+ f"The functions are:\n {missing_fastmath}\n"
)
raise ValueError(msg)

return


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", dest="pkg_dir")
args = parser.parse_args()

if args.pkg_dir:
pkg_dir = pathlib.Path(args.pkg_dir)
pkg_name = pkg_dir.name
check_fastmath(str(pkg_dir), pkg_name)
22 changes: 16 additions & 6 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ast
import importlib
import pathlib
import pkgutil
import site
import warnings

Expand All @@ -28,13 +27,17 @@ def get_njit_funcs():
out : list
A list of (`module_name`, `func_name`) pairs
"""
ignore_py_files = ["__init__", "__pycache__"]

pkg_dir = pathlib.Path(__file__).parent
module_names = [name for _, name, _ in pkgutil.iter_modules([str(pkg_dir)])]
module_names = []
for fname in pkg_dir.iterdir():
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
module_names.append(fname.stem)

njit_funcs = []

for module_name in module_names:
filepath = pathlib.Path(__file__).parent / f"{module_name}.py"
filepath = pkg_dir / f"{module_name}.py"
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
Expand All @@ -43,11 +46,18 @@ def get_njit_funcs():
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
if decorator.func.id == "njit":
njit_funcs.append((module_name, func_name))
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))

return njit_funcs

Expand Down
10 changes: 6 additions & 4 deletions stumpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,7 @@ def _count_diagonal_ndist(diags, m, n_A, n_B):

@njit(
# "i8[:, :](i8[:], i8, b1)"
fastmath=True
)
def _get_array_ranges(a, n_chunks, truncate):
"""
Expand Down Expand Up @@ -2404,6 +2405,7 @@ def _get_array_ranges(a, n_chunks, truncate):

@njit(
# "i8[:, :](i8, i8, b1)"
fastmath=True
)
def _get_ranges(size, n_chunks, truncate):
"""
Expand Down Expand Up @@ -3256,7 +3258,7 @@ def _select_P_ABBA_value(P_ABBA, k, custom_func=None):
return MPdist


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _merge_topk_PI(PA, PB, IA, IB):
"""
Merge two top-k matrix profiles `PA` and `PB`, and update `PA` (in place).
Expand Down Expand Up @@ -3329,7 +3331,7 @@ def _merge_topk_PI(PA, PB, IA, IB):
IA[i] = tmp_I


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _merge_topk_ρI(ρA, ρB, IA, IB):
"""
Merge two top-k pearson profiles `ρA` and `ρB`, and update `ρA` (in place).
Expand Down Expand Up @@ -3403,7 +3405,7 @@ def _merge_topk_ρI(ρA, ρB, IA, IB):
IA[i] = tmp_I


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _shift_insert_at_index(a, idx, v, shift="right"):
"""
If `shift=right` (default), all elements in `a[idx:]` are shifted to the right by
Expand Down Expand Up @@ -4379,7 +4381,7 @@ def get_ray_nworkers(ray_client):
return int(ray_client.cluster_resources().get("CPU"))


@njit
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _update_incremental_PI(D, P, I, excl_zone, n_appended=0):
"""
Given the 1D array distance profile, `D`, of the last subsequence of T,
Expand Down
17 changes: 15 additions & 2 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ check_print()
fi
}

check_fastmath()
{
echo "Checking Missing fastmath flags in njit functions"
./fastmath.py --check stumpy
check_errs $?
}

check_naive()
{
# Check if there are any naive implementations not at start of test file
Expand Down Expand Up @@ -146,14 +153,14 @@ set_ray_coveragerc()
show_coverage_report()
{
set_ray_coveragerc
coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage report -m --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

gen_coverage_xml_report()
{
# This function saves the coverage report in Cobertura XML format, which is compatible with codecov
set_ray_coveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

test_custom()
Expand Down Expand Up @@ -333,6 +340,12 @@ check_print
check_naive
check_ray


if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then
check_fastmath
fi


if [[ $test_mode == "notebooks" ]]; then
echo "Executing Tutorial Notebooks Only"
convert_notebooks
Expand Down

0 comments on commit 70e4e70

Please sign in to comment.