Skip to content

Commit

Permalink
Fix: rename KerasLSTMBaseEstimator attributes to underscored prefixed
Browse files Browse the repository at this point in the history
  • Loading branch information
RollerKnobster committed Jun 25, 2024
1 parent 0ca65e3 commit 8432d16
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions gordo/machine/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import abc
import logging
import io
import importlib
import tempfile
from pprint import pformat
Expand All @@ -11,7 +10,6 @@
from copy import copy, deepcopy
from importlib.util import find_spec

import h5py
import tensorflow.keras.models
from tensorflow.keras.models import load_model, save_model
from tensorflow.keras.preprocessing.sequence import pad_sequences, TimeseriesGenerator
Expand Down Expand Up @@ -481,11 +479,11 @@ def __init__(
Any arguments which are passed to the factory building function and/or any
additional args to be passed to the intermediate fit method.
"""
self.lookback_window = lookback_window
self.batch_size = batch_size
kwargs["lookback_window"] = lookback_window
kwargs["kind"] = kind
kwargs["batch_size"] = batch_size
self._lookback_window = lookback_window
self._batch_size = batch_size
kwargs["_lookback_window"] = lookback_window
kwargs["_kind"] = kind
kwargs["_batch_size"] = batch_size

# fit_generator_params is a set of strings with the keyword arguments of
# Keras fit_generator method (excluding "shuffle" as this will be hardcoded).
Expand Down Expand Up @@ -535,7 +533,7 @@ def _validate_and_fix_size_of_X(self, X):
)
X = X.reshape(len(X), 1)

if self.lookback_window >= X.shape[0]:
if self._lookback_window >= X.shape[0]:
raise ValueError(
"For KerasLSTMForecast lookback_window must be < size of X"
)
Expand Down Expand Up @@ -571,11 +569,11 @@ def fit( # type: ignore
# model using the scikit-learn wrapper.
tsg = create_keras_timeseriesgenerator(
X=X[
: self.lookahead + self.lookback_window
: self.lookahead + self._lookback_window
], # We only need a bit of the data
y=y[: self.lookahead + self.lookback_window],
y=y[: self.lookahead + self._lookback_window],
batch_size=1,
lookback_window=self.lookback_window,
lookback_window=self._lookback_window,
lookahead=self.lookahead,
)

Expand All @@ -587,7 +585,7 @@ def fit( # type: ignore
X=X,
y=y,
batch_size=self.batch_size,
lookback_window=self.lookback_window,
lookback_window=self._lookback_window,
lookahead=self.lookahead,
)

Expand Down Expand Up @@ -640,7 +638,7 @@ def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
X=X,
y=X,
batch_size=10000,
lookback_window=self.lookback_window,
lookback_window=self._lookback_window,
lookahead=self.lookahead,
)
kwargs.setdefault("verbose", 0)
Expand Down

0 comments on commit 8432d16

Please sign in to comment.