From 092fa4298ddcd5d69a8e66b8fdb08c7fd4fde2f4 Mon Sep 17 00:00:00 2001
From: MatthewMiddlehurst
Date: Mon, 5 Aug 2024 15:39:54 +0100
Subject: [PATCH] fixes

---
 aeon/base/_base_collection.py                 |  39 ++--
 aeon/classification/dictionary_based/_boss.py |   1 +
 aeon/testing/test_config.py                   |   4 +-
 aeon/testing/utils/_cicd_numba_caching.py     | 208 +++++++++---------
 4 files changed, 126 insertions(+), 126 deletions(-)

diff --git a/aeon/base/_base_collection.py b/aeon/base/_base_collection.py
index c1beb78c0e..3c572d13a5 100644
--- a/aeon/base/_base_collection.py
+++ b/aeon/base/_base_collection.py
@@ -233,23 +233,28 @@ def _check_shape(self, X):
         X : data structure
             Must be of type aeon.registry.COLLECTIONS_DATA_TYPES.
         """
-        if not self.get_tag("capability:unequal_length"):
-            nt = get_n_timepoints(X)
-            if nt != self.metadata_["n_timepoints"]:
-                raise ValueError(
-                    "X has different length to the data seen in fit but "
-                    "this classifier cannot handle unequal length series."
-                    f"length of train set was {self.metadata_['n_timepoints']}",
-                    f" length in predict is {nt}.",
-                )
-        if self.get_tag("capability:multivariate"):
-            nc = get_n_channels(X)
-            if nc != self.metadata_["n_channels"]:
-                raise ValueError(
-                    "X has different number of channels to the data seen in fit "
-                    "number of channels in train set was ",
-                    f"{self.metadata_['n_channels']} but in predict it is {nc}.",
-                )
+        # If metadata is empty, we have not seen any data in fit. If the estimator
+        # has not been fitted, _is_fitted should catch this. There are valid cases
+        # where metadata is empty but the estimator has been fitted, e.g. when
+        # loading a saved deep learning model.
+        if len(self.metadata_) != 0:
+            if not self.get_tag("capability:unequal_length"):
+                nt = get_n_timepoints(X)
+                if nt != self.metadata_["n_timepoints"]:
+                    raise ValueError(
+                        "X has a different length to the data seen in fit, but this "
+                        "classifier cannot handle unequal length series. Length of "
+                        f"the train set was {self.metadata_['n_timepoints']}, "
+                        f"length in predict is {nt}."
+                    )
+            if self.get_tag("capability:multivariate"):
+                nc = get_n_channels(X)
+                if nc != self.metadata_["n_channels"]:
+                    raise ValueError(
+                        "X has a different number of channels to the data seen in "
+                        "fit. Number of channels in the train set was "
+                        f"{self.metadata_['n_channels']}, but in predict it is {nc}."
+                    )
 
     @staticmethod
     def _get_X_metadata(X):
diff --git a/aeon/classification/dictionary_based/_boss.py b/aeon/classification/dictionary_based/_boss.py
index 2aca789555..b88b5e583a 100644
--- a/aeon/classification/dictionary_based/_boss.py
+++ b/aeon/classification/dictionary_based/_boss.py
@@ -658,6 +658,7 @@ def _shorten_bags(self, word_len, y):
         new_boss.n_classes_ = self.n_classes_
         new_boss.classes_ = self.classes_
         new_boss._class_dictionary = self._class_dictionary
+        new_boss.metadata_ = self.metadata_
         new_boss._is_fitted = True
 
         return new_boss
diff --git a/aeon/testing/test_config.py b/aeon/testing/test_config.py
index 2eae519bc3..dd3a47dadb 100644
--- a/aeon/testing/test_config.py
+++ b/aeon/testing/test_config.py
@@ -5,6 +5,7 @@
 
 import os
 
+import aeon.testing.utils._cicd_numba_caching  # noqa: F401
 from aeon.base import (
     BaseCollectionEstimator,
     BaseEstimator,
@@ -17,9 +18,6 @@
 # per os/version default is False, can be set to True by pytest --prtesting True flag
 PR_TESTING = False
 
-if os.environ.get("CICD_RUNNING") == "1":  # pragma: no cover
-    import aeon.testing.utils._cicd_numba_caching  # noqa: F401
-
 EXCLUDE_ESTIMATORS = [
     "TabularToSeriesAdaptor",
     "PandasTransformAdaptor",
diff --git a/aeon/testing/utils/_cicd_numba_caching.py b/aeon/testing/utils/_cicd_numba_caching.py
index 4aad4b49c2..f3afd4d755 100644
--- a/aeon/testing/utils/_cicd_numba_caching.py
+++ b/aeon/testing/utils/_cicd_numba_caching.py
@@ -1,111 +1,107 @@
+import os
 import pickle
 import subprocess
 
 import numba.core.caching
-
-def get_invalid_numba_files():
-    """Get the files that have been changed since the last commit.
-
-    This function is used to get the files that have been changed. This is needed
-    because custom logic to save the numba cache has been implemented for numba.
-    This function returns the file names that have been changed and if they appear
-    in here any numba functions cache are invalidated.
-
-    Returns
-    -------
-    list
-        List of file names that have been changed.
-    """
-    subprocess.run(["git", "fetch", "origin", "main"], check=True)
-
-    result = subprocess.run(
-        ["git", "diff", "--name-only", "origin/main"],
-        check=True,
-        capture_output=True,
-        text=True,  # Makes the output text instead of bytes
-    )
-
-    files = result.stdout.split("\n")
-
-    files = [file for file in files if file]
-
-    clean_files = []
-
-    for temp in files:
-        if temp.endswith(".py"):
-            clean_files.append((temp.split("/")[-1]).strip(".py"))
-
-    return clean_files
-
-
-# Retry the git fetch and git diff commands in case of failure
-retry = 0
-while retry < 3:
-    try:
-        CHANGED_FILES = get_invalid_numba_files()
-        break
-    except subprocess.CalledProcessError:
-        retry += 1
-
-# If the retry count is reached, raise an error
-if retry == 3:
-    raise Exception("Failed to get the changed files from git.")
-
-
-def custom_load_index(self):
-    """Overwrite load index method for numba.
-
-    This is used to overwrite the numba internal logic to allow for caching during
-    the cicd run. Numba traditionally checks the timestamp of the file and if it
-    has changed it invalidates the cache. This is not ideal for the cicd run as
-    the cache restore is always before the files (since it is cloned in) and
-    thus the cache is always invalidated. This custom method ignores the timestamp
-    element and instead just checks the file name. This isn't as fine grained as numba
-    but it is better to invalidate more and make sure the right function has been
-    compiled than try to be too clever and miss some.
-
-    Returns
-    -------
-    dict
-        Dictionary of the cached functions.
-    """
-    try:
-        with open(self._index_path, "rb") as f:
-            version = pickle.load(f)
-            data = f.read()
-    except FileNotFoundError:
-        return {}
-    if version != self._version:
-        return {}
-    stamp, overloads = pickle.loads(data)
-    cache_filename = self._index_path.split("/")[-1].split("-")[0].split(".")[0]
-    if stamp[1] != self._source_stamp[1] or cache_filename in CHANGED_FILES:
-        return {}
-    else:
-        return overloads
-
-
-original_load_index = numba.core.caching.IndexDataCacheFile._load_index
-numba.core.caching.IndexDataCacheFile._load_index = custom_load_index
-
-
-# Force all numba functions to be cached
-original_jit = numba.core.decorators._jit
-
-
-def custom_njit(*args, **kwargs):
-    """Force jit to cache.
-
-    This is used for libraries like stumpy that doesn't cache by default. This
-    function will force all functions running to be cache'd
-    """
-    target = kwargs["targetoptions"]
-    # This target can't be cached
-    if "no_cpython_wrapper" not in target:
-        kwargs["cache"] = True
-    return original_jit(*args, **kwargs)
-
-
-# Overwrite the jit function with the custom version
-numba.core.decorators._jit = custom_njit
+if os.environ.get("CICD_RUNNING") == "1":  # pragma: no cover
+
+    def get_invalid_numba_files():
+        """Get the names of Python files changed relative to origin/main.
+
+        This is needed because custom logic has been added to save and restore the
+        numba cache on CI. The returned names are compared against the cache index
+        file name in custom_load_index; any numba function defined in a changed
+        file has its cache invalidated.
+
+        Returns
+        -------
+        list
+            List of changed file names, without directories or the .py extension.
+        """
+        subprocess.run(["git", "fetch", "origin", "main"], check=True)
+
+        result = subprocess.run(
+            ["git", "diff", "--name-only", "origin/main"],
+            check=True,
+            capture_output=True,
+            text=True,  # return the output as text instead of bytes
+        )
+
+        files = result.stdout.split("\n")
+
+        files = [file for file in files if file]
+
+        clean_files = []
+
+        for temp in files:
+            if temp.endswith(".py"):
+                clean_files.append(temp.split("/")[-1][:-3])  # drop the ".py" suffix
+
+        return clean_files
+
+    # Retry the git fetch and git diff commands in case of failure
+    retry = 0
+    while retry < 3:
+        try:
+            CHANGED_FILES = get_invalid_numba_files()
+            break
+        except subprocess.CalledProcessError:
+            retry += 1
+
+    # If all retries failed, raise an error
+    if retry == 3:
+        raise Exception("Failed to get the changed files from git.")
+
+    def custom_load_index(self):
+        """Overwrite numba's cache index loading method.
+
+        This replaces numba's internal logic so that the cache can be reused during
+        the CI/CD run. Numba normally checks the timestamp of the source file and
+        invalidates the cache if it has changed. That does not work on CI/CD, where
+        the restored cache never matches the timestamps of the freshly cloned files,
+        so the cache is always invalidated. This custom method ignores the timestamp
+        and only checks the file name. This is less fine-grained than numba's own
+        check, but it is better to invalidate too much and be sure the right
+        function has been compiled than to be too clever and miss something.
+
+        Returns
+        -------
+        dict
+            Dictionary of cached function overloads, empty if the cache is invalid.
+        """
+        try:
+            with open(self._index_path, "rb") as f:
+                version = pickle.load(f)
+                data = f.read()
+        except FileNotFoundError:
+            return {}
+        if version != self._version:
+            return {}
+        stamp, overloads = pickle.loads(data)
+        cache_filename = self._index_path.split("/")[-1].split("-")[0].split(".")[0]
+        if stamp[1] != self._source_stamp[1] or cache_filename in CHANGED_FILES:
+            return {}
+        else:
+            return overloads
+
+    original_load_index = numba.core.caching.IndexDataCacheFile._load_index
+    numba.core.caching.IndexDataCacheFile._load_index = custom_load_index
+
+    # Force all numba functions to be cached by wrapping the jit decorator
+    original_jit = numba.core.decorators._jit
+
+    def custom_njit(*args, **kwargs):
+        """Force jit-compiled functions to be cached.
+
+        This is used for libraries such as stumpy that do not cache by default. It
+        forces caching for every compiled function where possible.
+        """
+        target = kwargs["targetoptions"]
+        # functions compiled with no_cpython_wrapper cannot be cached
+        if "no_cpython_wrapper" not in target:
+            kwargs["cache"] = True
+        return original_jit(*args, **kwargs)
+
+    # Overwrite the jit function with the custom version
+    numba.core.decorators._jit = custom_njit
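
Note: below is a minimal usage sketch (not part of the patch) of how the hooks are enabled after this change. test_config.py now always imports _cicd_numba_caching, and the module itself checks the CICD_RUNNING environment variable before monkey-patching numba. The sketch assumes the CI workflow exports CICD_RUNNING=1 before the test session and that it runs inside a clone of the aeon repository with an "origin" remote, since the module shells out to git at import time.

    import os

    # Assumed to be set by the CI workflow; without it the import below is a no-op.
    os.environ["CICD_RUNNING"] = "1"

    # Importing the module applies the monkey-patches only when the guard above is
    # satisfied. It runs git fetch/diff at import time, so it needs a git checkout.
    import aeon.testing.utils._cicd_numba_caching  # noqa: F401

    import numba.core.caching
    import numba.core.decorators

    # numba's cache index loading and jit wrapper are now the patched versions.
    print(numba.core.caching.IndexDataCacheFile._load_index.__name__)  # custom_load_index
    print(numba.core.decorators._jit.__name__)  # custom_njit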