Skip to content

Commit

Permalink
list key for sequentail features
Browse files Browse the repository at this point in the history
  • Loading branch information
MDobransky committed Sep 27, 2024
1 parent 79d51ad commit b5b15ae
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
project = "rialto"
copyright = "2022, Marek Dobransky"
author = "Marek Dobransky"
release = "1.3.0"
release = "2.0.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
7 changes: 5 additions & 2 deletions rialto/maker/feature_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def _set_values(self, df: DataFrame, key: typing.Union[str, typing.List[str]], m
:return: None
"""
self.data_frame = df
self.key = key
if isinstance(key, str):
self.key = [key]
else:
self.key = key
self.make_date = make_date

def _order_by_dependencies(self, feature_holders: typing.List[FeatureHolder]) -> typing.List[FeatureHolder]:
Expand Down Expand Up @@ -136,7 +139,7 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame:
)
if not keep_preexisting:
logger.info("Dropping non-selected columns")
self.data_frame = self.data_frame.select(self.key, *feature_names)
self.data_frame = self.data_frame.select(*self.key, *feature_names)
return self._filter_null_keys(self.data_frame)

def _make_aggregated(self) -> DataFrame:
Expand Down
7 changes: 7 additions & 0 deletions tests/maker/test_FeatureMaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def test_sequential_multi_key(input_df):
assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns


def test_sequential_multi_key_drop(input_df):
df, _ = FeatureMaker.make(
input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), sequential_outbound, keep_preexisting=False
)
assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns


def test_sequential_keeps(input_df):
df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True)
assert "AMT" in df.columns
Expand Down

0 comments on commit b5b15ae

Please sign in to comment.