Skip to content

Commit

Permalink
Merge branch 'sktime:main' into estimator-overview-detector
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinkatz001 authored Nov 20, 2024
2 parents 7ca6225 + d7f5823 commit aee26c2
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 11 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -3191,6 +3191,15 @@
"contributions": [
"doc"
]
},
{
"login": "mjste",
"name": "Michał Stefanik",
"avatar_url": "https://avatars.githubusercontent.com/mjste",
"profile": "https://github.com/mjste",
"contributions": [
"doc"
]
}
]
}
4 changes: 2 additions & 2 deletions examples/01c_forecasting_hierarchical_global.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@
"* instance index: the first element of pairs in `obj.index` is interpreted as an instance index. \n",
"* variables: columns of `obj` correspond to different variables\n",
"* variable names: column names `obj.columns`\n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond correspond to the different time points.\n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond to the different time points.\n",
"* time index: the second element of pairs in `obj.index` is interpreted as a time index. \n",
"* capabilities: can represent panels of multivariate series; can represent unequally spaced series; can represent panels of unequally supported series; cannot represent panels of series with different sets of variables."
]
Expand Down Expand Up @@ -416,7 +416,7 @@
"* hierarchy: the non-time-like indices in `obj.index` are interpreted as a hierarchy identifying index. \n",
"* variables: columns of `obj` correspond to different variables\n",
"* variable names: column names `obj.columns`\n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond correspond to the different time points.\n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond to the different time points.\n",
"* time index: the last element of tuples in `obj.index` is interpreted as a time index. \n",
"* capabilities: can represent hierarchical series; can represent unequally spaced series; can represent unequally supported hierarchical series; cannot represent hierarchical series with different sets of variables."
]
Expand Down
4 changes: 2 additions & 2 deletions examples/AA_datatypes_and_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@
"* instance index: the first element of pairs in `obj.index` (0-th level value) is interpreted as an instance index, we call it \"instance index\" below.\n",
"* instances: rows with the same \"instance index\" index value correspond to the same instance; rows with different \"instance index\" values correspond to different instances. \n",
"* time index: the second element of pairs in `obj.index` (1-st level value) is interpreted as a time index, we call it \"time index\" below. \n",
"* time points: rows of `obj` with the same \"time index\" value correspond correspond to the same time point; rows of `obj` with different \"time index\" index correspond correspond to the different time points.\n",
"* time points: rows of `obj` with the same \"time index\" value correspond to the same time point; rows of `obj` with different \"time index\" index correspond to the different time points.\n",
"* variables: columns of `obj` correspond to different variables\n",
"* variable names: column names `obj.columns`\n",
"* capabilities: can represent panels of multivariate series; can represent unequally spaced series; can represent panels of unequally supported series; cannot represent panels of series with different sets of variables."
Expand Down Expand Up @@ -759,7 +759,7 @@
"* hierarchy level: rows with the same non-time-like index values correspond to the same hierarchy unit; rows with different non-time-like index combination correspond to different hierarchy unit.\n",
"* hierarchy: the non-time-like indices in `obj.index` are interpreted as a hierarchy identifying index. \n",
"* time index: the last element of tuples in `obj.index` is interpreted as a time index. \n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond correspond to the different time points.\n",
"* time points: rows of `obj` with the same `\"timepoints\"` index correspond to the same time point; rows of `obj` with different `\"timepoints\"` index correspond to the different time points.\n",
"* variables: columns of `obj` correspond to different variables\n",
"* variable names: column names `obj.columns`\n",
"* capabilities: can represent hierarchical series; can represent unequally spaced series; can represent unequally supported hierarchical series; cannot represent hierarchical series with different sets of variables."
Expand Down
61 changes: 57 additions & 4 deletions sktime/annotation/tests/test_all_annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class AnnotatorsFixtureGenerator(BaseFixtureGenerator):

estimator_type_filter = "detector"

fixture_sequence = [
"estimator_class",
"estimator_instance",
"fitted_estimator",
"scenario",
"method_nsc",
"method_nsc_arraylike",
]


class TestAllAnnotators(AnnotatorsFixtureGenerator, QuickTester):
"""Module level tests for all sktime annotators."""
Expand All @@ -39,16 +48,60 @@ def test_output_type(self, estimator_instance):
"""Test annotator output type."""
estimator = estimator_instance

arg = make_annotation_problem(
X_train = make_annotation_problem(
n_timepoints=50, estimator_type=estimator.get_tag("distribution_type")
)
estimator.fit(arg)
arg = make_annotation_problem(
estimator.fit(X_train)
X_test = make_annotation_problem(
n_timepoints=10, estimator_type=estimator.get_tag("distribution_type")
)
y_pred = estimator.predict(arg)
y_test = estimator.predict(X_test)
assert isinstance(y_test, (pd.Series, np.ndarray))

def test_transform_output_type(self, estimator_instance):
    """Check the type and length of the ``transform`` output.

    The estimator is fitted on data generated with ``n_timepoints=50`` and
    then applied to data generated with ``n_timepoints=10``. The result must
    be a ``pd.Series`` or ``np.ndarray`` whose length equals the length of
    the data passed to ``transform``.
    """
    detector = estimator_instance

    fit_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=50,
    )
    detector.fit(fit_data)

    # The tag is looked up again after fitting, as in the fit call above.
    new_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=10,
    )
    transformed = detector.transform(new_data)

    assert isinstance(transformed, (pd.Series, np.ndarray))
    assert len(transformed) == len(new_data)

def test_predict_points(self, estimator_instance):
    """Check the output type of ``predict_points``.

    The estimator is fitted on data generated with ``n_timepoints=50`` and
    ``predict_points`` is called on data generated with ``n_timepoints=10``.
    The result must be a ``pd.Series`` or ``np.ndarray``.
    """
    detector = estimator_instance

    fit_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=50,
    )
    detector.fit(fit_data)

    # The tag is looked up again after fitting, as in the fit call above.
    new_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=10,
    )
    points = detector.predict_points(new_data)

    assert isinstance(points, (pd.Series, np.ndarray))

def test_predict_segments(self, estimator_instance):
    """Check the output format of ``predict_segments``.

    The estimator is fitted on data generated with ``n_timepoints=50`` and
    ``predict_segments`` is called on data generated with ``n_timepoints=10``.
    The result must be a ``pd.Series`` with an interval-typed index and
    integer-typed values.
    """
    detector = estimator_instance

    fit_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=50,
    )
    detector.fit(fit_data)

    # The tag is looked up again after fitting, as in the fit call above.
    new_data = make_annotation_problem(
        estimator_type=detector.get_tag("distribution_type"),
        n_timepoints=10,
    )
    segments = detector.predict_segments(new_data)

    assert isinstance(segments, pd.Series)
    assert isinstance(segments.index.dtype, pd.IntervalDtype)
    assert pd.api.types.is_integer_dtype(segments)

def test_annotator_tags(self, estimator_class):
"""Check the learning_type and task tags are valid."""
check_task(estimator_class.get_class_tag("task"))
Expand Down
2 changes: 1 addition & 1 deletion sktime/datatypes/_panel/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class PanelPdMultiIndex(ScitypePanel):
is interpreted as a time index, we call it "time index" below.
* time points: rows of ``obj`` with the same "time index" value correspond
to the same time point; rows of `obj` with different "time index"
index correspond correspond to the different time points.
index correspond to the different time points.
* variables: columns of ``obj`` correspond to different variables
* variable names: column names ``obj.columns``
Expand Down
1 change: 1 addition & 0 deletions sktime/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Croston(BaseForecaster):
# estimator type
# --------------
"requires-fh-in-fit": False, # is forecasting horizon already required in fit?
"ignores-exogeneous-X": True,
}

def __init__(self, smoothing=0.1):
Expand Down
58 changes: 56 additions & 2 deletions sktime/tests/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@
"test_update_predict_single", # see 2997, sporadic failure, unknown cause
"test__y_when_refitting", # see 3176
],
# GGS inherits from BaseEstimator which breaks this test
"GreedyGaussianSegmentation": ["test_inheritance", "test_create_test_instance"],
"InformationGainSegmentation": [
"test_inheritance",
"test_create_test_instance",
Expand Down Expand Up @@ -248,6 +246,62 @@
"test_persistence_via_pickle",
"test_save_estimators_to_file",
],
# The following detectors are not interface compliant. See PR 6958
"PoissonHMM": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"HMM": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"ClaSPSegmentation": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"ClusterSegmenter": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"AnnotatorPipeline": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"BinarySegmentation": [
"test_predict_segments",
],
"GreedyGaussianSegmentation": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
"test_inheritance",
"test_create_test_instance",
],
"PyODAnnotator": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"GaussianHMM": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"GMMHMM": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
"SubLOF": [
"test_predict_points",
"test_predict_segments",
"test_transform_output_type",
],
}

# exclude tests but keyed by test name
Expand Down
1 change: 1 addition & 0 deletions sktime/transformations/series/holiday/_holidayfeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class HolidayFeatures(BaseTransformer):
Returns country holiday features with custom holiday windows
>>> from sktime.transformations.series.holiday import HolidayFeatures
>>> transformer = HolidayFeatures(
... calendar=country_holidays(country="FR"),
... return_categorical=True,
Expand Down

0 comments on commit aee26c2

Please sign in to comment.