Skip to content

Commit

Permalink
feature(Composites): made Pipeline class Tunable
Browse files Browse the repository at this point in the history
  • Loading branch information
almostintuitive committed Jun 28, 2023
1 parent 93498f6 commit fb832d3
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/fold/composites/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from fold.base.classes import Artifact
from fold.base.classes import Artifact, Tunable

from ..base import Composite, Pipelines, Transformations, get_concatenated_names
from ..transformations.columns import SelectColumns
Expand Down Expand Up @@ -92,7 +92,7 @@ def clone(self, clone_children: Callable) -> Concat:
return clone


class Pipeline(Composite):
class Pipeline(Composite, Tunable):
"""
An optional wrappers that is equivalent to using a single array for the transformations.
It executes the transformations sequentially, in the order they are provided.
Expand All @@ -107,9 +107,15 @@ class Pipeline(Composite):

properties = Composite.Properties(primary_only_single_pipeline=True)

def __init__(self, pipeline: Pipeline, name: Optional[str] = None) -> None:
def __init__(
self,
pipeline: Pipeline,
name: Optional[str] = None,
params_to_try: Optional[dict] = None,
) -> None:
self.pipeline = wrap_in_double_list_if_needed(pipeline)
self.name = name or "Pipeline-" + get_concatenated_names(pipeline)
self.params_to_try = params_to_try

def postprocess_result_primary(
self, results: List[pd.DataFrame], y: Optional[pd.Series]
Expand All @@ -125,6 +131,19 @@ def clone(self, clone_children: Callable) -> Pipeline:
clone.name = self.name
return clone

def get_params(self) -> dict:
return dict(name=self.name)

def get_params_to_try(self) -> Optional[dict]:
return dict()

def clone_with_params(
self, parameters: dict, clone_children: Optional[Callable] = None
) -> Tunable:
instance = self.clone(self, clone_children)
instance.name = parameters["name"]
return instance


def TransformColumn(
columns: Union[List[str], str],
Expand Down

0 comments on commit fb832d3

Please sign in to comment.