Commit 092fa42

fixes

MatthewMiddlehurst committed Aug 5, 2024
1 parent 243f87e commit 092fa42
Showing 4 changed files with 126 additions and 126 deletions.
39 changes: 22 additions & 17 deletions aeon/base/_base_collection.py
@@ -233,23 +233,28 @@ def _check_shape(self, X):
        X : data structure
            Must be of type aeon.registry.COLLECTIONS_DATA_TYPES.
        """
-        if not self.get_tag("capability:unequal_length"):
-            nt = get_n_timepoints(X)
-            if nt != self.metadata_["n_timepoints"]:
-                raise ValueError(
-                    "X has different length to the data seen in fit but "
-                    "this classifier cannot handle unequal length series. "
-                    f"Length of train set was {self.metadata_['n_timepoints']}"
-                    f" but length in predict is {nt}."
-                )
-        if self.get_tag("capability:multivariate"):
-            nc = get_n_channels(X)
-            if nc != self.metadata_["n_channels"]:
-                raise ValueError(
-                    "X has different number of channels to the data seen in "
-                    "fit. Number of channels in train set was "
-                    f"{self.metadata_['n_channels']} but in predict it is {nc}."
-                )
+        # if metadata is empty, then we have not seen any data in fit. If the
+        # estimator has not been fitted, then _is_fitted should catch this.
+        # There are valid cases where metadata is empty and the estimator has
+        # been fitted, i.e. deep learner loading.
+        if len(self.metadata_) != 0:
+            if not self.get_tag("capability:unequal_length"):
+                nt = get_n_timepoints(X)
+                if nt != self.metadata_["n_timepoints"]:
+                    raise ValueError(
+                        "X has different length to the data seen in fit but "
+                        "this classifier cannot handle unequal length series. "
+                        f"Length of train set was {self.metadata_['n_timepoints']}"
+                        f" but length in predict is {nt}."
+                    )
+            if self.get_tag("capability:multivariate"):
+                nc = get_n_channels(X)
+                if nc != self.metadata_["n_channels"]:
+                    raise ValueError(
+                        "X has different number of channels to the data seen "
+                        "in fit. Number of channels in train set was "
+                        f"{self.metadata_['n_channels']} but in predict it is {nc}."
+                    )

    @staticmethod
    def _get_X_metadata(X):
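As a quick illustration of the guarded behaviour (a hedged sketch: the classifier choice and toy shapes are mine, not part of the commit), any collection estimator without the unequal-length capability now only runs these comparisons when fit recorded some metadata:

import numpy as np

from aeon.classification.interval_based import TimeSeriesForestClassifier

X_train = np.random.random((10, 1, 100))  # 10 series, 1 channel, 100 timepoints
y_train = np.array([0, 1] * 5)

clf = TimeSeriesForestClassifier(n_estimators=2)
clf.fit(X_train, y_train)  # metadata_ records n_channels=1, n_timepoints=100

clf.predict(X_train)  # same shape as seen in fit: passes _check_shape
X_short = np.random.random((10, 1, 50))
clf.predict(X_short)  # raises ValueError: series length differs from fit

# A deep learner restored from disk has _is_fitted=True but an empty
# metadata_ dict; after this commit _check_shape skips the comparisons
# instead of failing on the missing "n_timepoints" key.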
1 change: 1 addition & 0 deletions aeon/classification/dictionary_based/_boss.py
@@ -658,6 +658,7 @@ def _shorten_bags(self, word_len, y):
        new_boss.n_classes_ = self.n_classes_
        new_boss.classes_ = self.classes_
        new_boss._class_dictionary = self._class_dictionary
+        new_boss.metadata_ = self.metadata_
        new_boss._is_fitted = True

        return new_boss
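For context, `_shorten_bags` builds a new BOSS estimator by hand and marks it fitted without calling `fit`; before this commit the copy carried no `metadata_`, so the predict-time shape check above had nothing to compare against. A minimal sketch of the pattern with a stand-in class (not aeon code):

class TinyEstimator:  # stand-in for a hand-assembled aeon estimator
    pass

fitted = TinyEstimator()
fitted.metadata_ = {"n_channels": 1, "n_timepoints": 100}

clone = TinyEstimator()
clone.metadata_ = fitted.metadata_  # the line this commit adds for BOSS
clone._is_fitted = True  # marks the clone fitted without running fit()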
4 changes: 1 addition & 3 deletions aeon/testing/test_config.py
@@ -5,6 +5,7 @@

import os

+import aeon.testing.utils._cicd_numba_caching  # noqa: F401
from aeon.base import (
    BaseCollectionEstimator,
    BaseEstimator,
@@ -17,9 +18,6 @@
# per os/version default is False, can be set to True by pytest --prtesting True flag
PR_TESTING = False

-if os.environ.get("CICD_RUNNING") == "1":  # pragma: no cover
-    import aeon.testing.utils._cicd_numba_caching  # noqa: F401
-
EXCLUDE_ESTIMATORS = [
    "TabularToSeriesAdaptor",
    "PandasTransformAdaptor",
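The environment check is not removed here, it moves inside `_cicd_numba_caching` itself (next file), so the import is unconditional but the monkeypatching still only runs on CI/CD. A minimal sketch of that self-gating module pattern, with hypothetical names:

# gated_patch.py -- hypothetical module illustrating the pattern
import os

if os.environ.get("CICD_RUNNING") == "1":  # gate the side effects, not the import
    ...  # monkeypatching happens here

# An importing module (e.g. test_config.py) can now do a plain
#   import gated_patch  # noqa: F401
# and the module body decides whether to do any work.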
208 changes: 102 additions & 106 deletions aeon/testing/utils/_cicd_numba_caching.py
@@ -1,111 +1,107 @@
import os
import pickle
import subprocess

import numba.core.caching


-def get_invalid_numba_files():
-    """Get the files that have been changed since the last commit.
-
-    This function returns the names of the files that have been changed
-    relative to origin/main. It is needed because custom logic for saving
-    the numba cache has been implemented, and any numba function defined in
-    one of these files has its cache invalidated.
-
-    Returns
-    -------
-    list
-        List of file names that have been changed.
-    """
-    subprocess.run(["git", "fetch", "origin", "main"], check=True)
-
-    result = subprocess.run(
-        ["git", "diff", "--name-only", "origin/main"],
-        check=True,
-        capture_output=True,
-        text=True,  # Makes the output text instead of bytes
-    )
-
-    files = result.stdout.split("\n")
-
-    files = [file for file in files if file]
-
-    clean_files = []
-
-    for temp in files:
-        if temp.endswith(".py"):
-            # slice off the ".py" suffix; str.strip would remove matching
-            # characters from both ends, mangling names such as "copy.py"
-            clean_files.append(temp.split("/")[-1][: -len(".py")])
-
-    return clean_files
-
-
-# Retry the git fetch and git diff commands in case of failure
-retry = 0
-while retry < 3:
-    try:
-        CHANGED_FILES = get_invalid_numba_files()
-        break
-    except subprocess.CalledProcessError:
-        retry += 1
-
-# If the retry count is reached, raise an error
-if retry == 3:
-    raise Exception("Failed to get the changed files from git.")
-
-
-def custom_load_index(self):
-    """Overwrite the load index method for numba.
-
-    This overwrites numba's internal logic to allow caching during the CI/CD
-    run. Numba normally checks the timestamp of the source file and
-    invalidates the cache if it has changed. This is not ideal for CI/CD, as
-    the cache restore always happens before the repository files exist
-    (they are cloned in afterwards), so the cache would always be
-    invalidated. This custom method ignores the timestamp element and
-    instead checks the file name. This isn't as fine grained as numba, but
-    it is better to invalidate more and make sure the right function has
-    been compiled than to be too clever and miss some.
-
-    Returns
-    -------
-    dict
-        Dictionary of the cached functions.
-    """
-    try:
-        with open(self._index_path, "rb") as f:
-            version = pickle.load(f)
-            data = f.read()
-    except FileNotFoundError:
-        return {}
-    if version != self._version:
-        return {}
-    stamp, overloads = pickle.loads(data)
-    cache_filename = self._index_path.split("/")[-1].split("-")[0].split(".")[0]
-    if stamp[1] != self._source_stamp[1] or cache_filename in CHANGED_FILES:
-        return {}
-    else:
-        return overloads
-
-
-original_load_index = numba.core.caching.IndexDataCacheFile._load_index
-numba.core.caching.IndexDataCacheFile._load_index = custom_load_index
-
-
-# Force all numba functions to be cached
-original_jit = numba.core.decorators._jit
-
-
-def custom_njit(*args, **kwargs):
-    """Force jit to cache.
-
-    This is used for libraries like stumpy that do not cache by default. It
-    forces all jitted functions to be cached.
-    """
-    target = kwargs["targetoptions"]
-    # This target can't be cached
-    if "no_cpython_wrapper" not in target:
-        kwargs["cache"] = True
-    return original_jit(*args, **kwargs)
-
-
-# Overwrite the jit function with the custom version
-numba.core.decorators._jit = custom_njit
+if os.environ.get("CICD_RUNNING") == "1":  # pragma: no cover
+
+    def get_invalid_numba_files():
+        """Get the files that have been changed since the last commit.
+
+        This function returns the names of the files that have been changed
+        relative to origin/main. It is needed because custom logic for
+        saving the numba cache has been implemented, and any numba function
+        defined in one of these files has its cache invalidated.
+
+        Returns
+        -------
+        list
+            List of file names that have been changed.
+        """
+        subprocess.run(["git", "fetch", "origin", "main"], check=True)
+
+        result = subprocess.run(
+            ["git", "diff", "--name-only", "origin/main"],
+            check=True,
+            capture_output=True,
+            text=True,  # Makes the output text instead of bytes
+        )
+
+        files = result.stdout.split("\n")
+
+        files = [file for file in files if file]
+
+        clean_files = []
+
+        for temp in files:
+            if temp.endswith(".py"):
+                # slice off the ".py" suffix; str.strip would remove matching
+                # characters from both ends, mangling names such as "copy.py"
+                clean_files.append(temp.split("/")[-1][: -len(".py")])
+
+        return clean_files
+
+    # Retry the git fetch and git diff commands in case of failure
+    retry = 0
+    while retry < 3:
+        try:
+            CHANGED_FILES = get_invalid_numba_files()
+            break
+        except subprocess.CalledProcessError:
+            retry += 1
+
+    # If the retry count is reached, raise an error
+    if retry == 3:
+        raise Exception("Failed to get the changed files from git.")
+
+    def custom_load_index(self):
+        """Overwrite the load index method for numba.
+
+        This overwrites numba's internal logic to allow caching during the
+        CI/CD run. Numba normally checks the timestamp of the source file
+        and invalidates the cache if it has changed. This is not ideal for
+        CI/CD, as the cache restore always happens before the repository
+        files exist (they are cloned in afterwards), so the cache would
+        always be invalidated. This custom method ignores the timestamp
+        element and instead checks the file name. This isn't as fine grained
+        as numba, but it is better to invalidate more and make sure the
+        right function has been compiled than to be too clever and miss
+        some.
+
+        Returns
+        -------
+        dict
+            Dictionary of the cached functions.
+        """
+        try:
+            with open(self._index_path, "rb") as f:
+                version = pickle.load(f)
+                data = f.read()
+        except FileNotFoundError:
+            return {}
+        if version != self._version:
+            return {}
+        stamp, overloads = pickle.loads(data)
+        cache_filename = self._index_path.split("/")[-1].split("-")[0].split(".")[0]
+        if stamp[1] != self._source_stamp[1] or cache_filename in CHANGED_FILES:
+            return {}
+        else:
+            return overloads
+
+    original_load_index = numba.core.caching.IndexDataCacheFile._load_index
+    numba.core.caching.IndexDataCacheFile._load_index = custom_load_index
+
+    # Force all numba functions to be cached
+    original_jit = numba.core.decorators._jit
+
+    def custom_njit(*args, **kwargs):
+        """Force jit to cache.
+
+        This is used for libraries like stumpy that do not cache by
+        default. It forces all jitted functions to be cached.
+        """
+        target = kwargs["targetoptions"]
+        # This target can't be cached
+        if "no_cpython_wrapper" not in target:
+            kwargs["cache"] = True
+        return original_jit(*args, **kwargs)
+
+    # Overwrite the jit function with the custom version
+    numba.core.decorators._jit = custom_njit
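To make the patched invalidation rule concrete, here is a small self-contained sketch of the decision `custom_load_index` makes (toy values; the (timestamp, size) stamp layout is assumed from numba's internals):

CHANGED_FILES = ["_boss"]  # module names that differ from origin/main

def cache_is_valid(index_path, source_stamp, cached_stamp):
    # numba would compare the whole cached stamp; the patch drops the
    # timestamp component (stamp[0]) and instead invalidates whenever the
    # module name appears in the changed-files list.
    module = index_path.split("/")[-1].split("-")[0].split(".")[0]
    return cached_stamp[1] == source_stamp[1] and module not in CHANGED_FILES

# Untouched module, matching size -> cache kept despite a newer timestamp.
print(cache_is_valid("cache/_sfa.transform-10.py311.nbi", (2.0, 840), (1.0, 840)))  # True
# Module changed on this branch -> cache dropped even though sizes match.
print(cache_is_valid("cache/_boss.fit-20.py311.nbi", (2.0, 512), (1.0, 512)))  # False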
