Skip to content

Commit

Permalink
remove base deep network old class (#1817)
Browse files Browse the repository at this point in the history
  • Loading branch information
hadifawaz1999 authored Jul 18, 2024
1 parent 64e9eb8 commit 8d87b28
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 108 deletions.
2 changes: 1 addition & 1 deletion aeon/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@
from aeon.networks._mlp import MLPNetwork
from aeon.networks._resnet import ResNetNetwork
from aeon.networks._tapnet import TapNetNetwork
from aeon.networks.base import BaseDeepLearningNetwork, BaseDeepNetwork
from aeon.networks.base import BaseDeepLearningNetwork
41 changes: 0 additions & 41 deletions aeon/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,12 @@

from abc import ABC, abstractmethod

from deprecated.sphinx import deprecated

from aeon.base import BaseObject
from aeon.utils.validation._dependencies import (
_check_estimator_deps,
_check_python_version,
_check_soft_dependencies,
)


# TODO: remove v0.11.0
@deprecated(
version="0.10.0",
reason="BaseDeepNetwork will be removed in 0.11.0, use BaseDeepLearningNetwork "
"instead. The new class does not inherit from BaseObject.",
category=FutureWarning,
)
class BaseDeepNetwork(BaseObject, ABC):
"""Abstract base class for deep learning networks."""

def __init__(self):
super().__init__()
_check_estimator_deps(self)

_config = {
"python_dependencies": ["tensorflow"],
"python_version": "<3.12",
"structure": "encoder",
}

@abstractmethod
def build_network(self, input_shape, **kwargs):
"""Construct a network and return its input and output layers.
Parameters
----------
input_shape : tuple
The shape of the data fed into the input layer
Returns
-------
input_layer : a keras layer
output_layer : a keras layer
"""
...


class BaseDeepLearningNetwork(ABC):
"""Abstract base class for deep learning networks."""

Expand Down
5 changes: 1 addition & 4 deletions aeon/networks/tests/test_all_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def test_all_networks_functionality(network):
"""Test the functionality of all networks."""
input_shape = (100, 2)

if not (
network.__name__
in ["BaseDeepNetwork", "BaseDeepLearningNetwork", "EncoderNetwork"]
):
if not (network.__name__ in ["BaseDeepLearningNetwork", "EncoderNetwork"]):
if _check_soft_dependencies(
network._config["python_dependencies"], severity="none"
) and _check_python_version(network._config["python_version"], severity="none"):
Expand Down
2 changes: 1 addition & 1 deletion docs/api_reference/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Deep learning networks
:toctree: auto_generated/
:template: class.rst

BaseDeepNetwork
BaseDeepLearningNetwork
CNNNetwork
EncoderNetwork
FCNNetwork
Expand Down
122 changes: 61 additions & 61 deletions examples/base/base_classes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# Overview of the base class structure\n",
"\n",
Expand All @@ -10,13 +13,13 @@
"the following simplified UML\n",
"\n",
"<img src=\"img/uml.png\" width=\"800\" alt=\"Basic class hierarchy\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## sklearn `BaseEstimator` and aeon `BaseObject`\n",
"\n",
Expand All @@ -29,106 +32,106 @@
"main functionality and may differ in details from the actual implementations.\n",
"\n",
"<img src=\"img/sklearn_base.png\" width=\"500\" alt=\"sklearn base\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The aeon class `BaseObject` extends `BaseEstimator` and adds the tagging method and\n",
"some other functionality used in `aeon` estimators\n",
"\n",
"<img src=\"img/base_object.png\" width=\"600\" alt=\"Base object\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## aeons ``BaseEstimator``, `BaseDeepNetwork` and `BaseMetric`\n",
"## aeons ``BaseEstimator``, `BaseDeepLearningNetwork` and `BaseMetric`\n",
"\n",
"Three classes extend `BaseObject`: ``BaseEstimator``, `BaseDeepNetwork` and\n",
"Three classes extend `BaseObject`: ``BaseEstimator``, `BaseDeepLearningNetwork` and\n",
"`BaseMetric`.\n",
"\n",
"`BaseDeepNetwork` is the base class for all the deep learning networks defined in the\n",
"`BaseDeepLearningNetwork` is the base class for all the deep learning networks defined in the\n",
"`networks` module. It has a single abstract method `build_network`.\n",
"\n",
"<img src=\"img/base_deep_network.png\" width=\"600\" alt=\"Base object\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The `BaseMetric` class is the base class for forecasting performance metrics. It has\n",
"a single abstract method `evaluate`.\n",
"\n",
"<img src=\"img/base_metric.png\" width=\"600\" alt=\"Base metric\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The ``BaseEstimator`` class is the base class for the majority of classes in aeon.\n",
"Anything that uses fit and predict in aeon. It contains a protected attribute\n",
"`_is_fitted` and checks as to the value of this attribute. It also has a method to\n",
"get fitted parameters.\n",
"\n",
"<img src=\"img/base_estimator.png\" width=\"500\" alt=\"Base Estimator\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"`BaseEstimator` has four direct base classes: `BaseForecaster`, `BaseTransformer` and `BaseCollectionEstimator`.\n",
"\n",
"\n",
"<img src=\"img/uml2.png\" width=\"600\" alt=\"Top level class hierarchy\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseForecaster` (aeon.forecasting.base)\n",
"contains the forecasting specific methods. More details are available in the [API](https://www.aeon-toolkit.org/en/latest/api_reference/forecasting.html). `BaseForecaster` has the following concrete methods:\n",
"\n",
"<img src=\"img/base_forecaster.png\" width=\"700\" alt=\"Base forecaster\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseTransformer` (aeon.transformations.base)\n",
"\n",
"Is the base class for all transformers, including single series transformers and\n",
"collections transformers.\n",
"\n",
"<img src=\"img/base_transformer.png\" width=\"700\" alt=\"Base transformer\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseCollectionEstimator` (aeon.base)\n",
"\n",
Expand All @@ -139,13 +142,13 @@
"\n",
"<img src=\"img/base_collection_estimator.png\" width=\"700\" alt=\"Base\n",
"transformer\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The subclasses of `BaseCollectionEstimator` are as follows\n",
"\n",
Expand All @@ -163,13 +166,13 @@
" `BaseCollectionEstimator`.\n",
"\n",
"<img src=\"img/base_classifier.png\" width=\"700\" alt=\"Top level class hierarchy\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseRegressor` (aeon.regression)\n",
"\n",
Expand All @@ -178,13 +181,13 @@
"\n",
"\n",
"<img src=\"img/base_regressor.png\" width=\"700\" alt=\"Top level class hierarchy\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseClusterer` (aeon.clustering)\n",
"\n",
Expand All @@ -194,13 +197,13 @@
"\n",
"<img src=\"img/base_clusterer.png\" width=\"700\" alt=\"Base\n",
"transformer\">"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## `BaseCollectionTransformer` (aeon.transformations.collection)\n",
"\n",
Expand All @@ -213,17 +216,14 @@
"\n",
"<img src=\"img/base_collection_transformer.png\" width=\"700\" alt=\"Base\n",
"transformer\">\n"
],
"metadata": {
"collapsed": false
}
]
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
}
},
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 8d87b28

Please sign in to comment.