Skip to content

Commit

Permalink
move split method to gaze dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
SiQube committed Nov 17, 2024
1 parent 21fd0d2 commit 0856658
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 24 deletions.
40 changes: 16 additions & 24 deletions src/pymovements/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,35 +235,27 @@ def _split_gaze_data(
self,
by: list[str] | str,
) -> None:
"""Split gaze data into seperated GazeDataFrame's.
"""Split gaze data into separated GazeDataFrame's.
Parameters
----------
by: list[str] | str
Column's to split dataframe by.
Column(s) to split dataframe by.
"""
if isinstance(by, str):
by = [by]
new_data = [
(
GazeDataFrame(
new_frame,
experiment=_frame.experiment,
trial_columns=self.definition.trial_columns,
time_column=self.definition.time_column,
time_unit=self.definition.time_unit,
position_columns=self.definition.position_columns,
velocity_columns=self.definition.velocity_columns,
acceleration_columns=self.definition.acceleration_columns,
distance_column=self.definition.distance_column,
),
fileinfo_row,
)
for (_frame, fileinfo_row) in zip(self.gaze, self.fileinfo['gaze'].to_dicts())
for new_frame in _frame.frame.partition_by(by=by)
]
self.gaze = [data[0] for data in new_data]
self.fileinfo['gaze'] = pl.concat([pl.from_dict(data[1]) for data in new_data])
by = [by] if isinstance(by, str) else by

fileinfo_dicts = self.fileinfo['gaze'].to_dicts()

all_gaze_frames = []
all_fileinfo_rows = []

for frame, fileinfo_row in zip(self.gaze, fileinfo_dicts):
split_frames = frame.split(by=by)
all_gaze_frames.extend(split_frames)
all_fileinfo_rows.extend([fileinfo_row] * len(split_frames))

self.gaze = all_gaze_frames
self.fileinfo['gaze'] = pl.concat([pl.from_dict(row) for row in all_fileinfo_rows])

def split_precomputed_events(
self,
Expand Down
41 changes: 41 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ def __init__(

# Remove this attribute once #893 is fixed
self._metadata: dict[str, Any] | None = None
self.auto_column_detect = auto_column_detect
self.time_column = time_column
self.time_unit = time_unit
self.pixel_columns = pixel_columns
self.position_columns = position_columns
self.velocity_columns = velocity_columns
self.acceleration_columns = acceleration_columns
self.distance_column = distance_column

def apply(
self,
Expand All @@ -307,6 +315,39 @@ def apply(
else:
raise ValueError(f"unsupported method '{function}'")

def split(self, by: list[str] | str) -> list[GazeDataFrame]:
    """Partition this GazeDataFrame into one GazeDataFrame per group.

    The underlying ``self.frame`` is partitioned with ``partition_by`` on the
    given column(s), and every resulting partition is wrapped in a freshly
    constructed ``GazeDataFrame``.

    Parameters
    ----------
    by: list[str] | str
        Name(s) of the column(s) to partition on. A single string is treated
        as a one-element list of column names.

    Returns
    -------
    list[GazeDataFrame]
        One new ``GazeDataFrame`` per partition. Each is built with this
        instance's ``experiment``, ``auto_column_detect``, ``trial_columns``,
        ``time_column``, ``time_unit``, ``position_columns``,
        ``velocity_columns``, ``acceleration_columns`` and ``distance_column``.
    """
    # Normalise a bare column name to the list form handed to partition_by.
    group_columns = [by] if isinstance(by, str) else by

    def _rebuild(partition):
        # Wrap a single partition using the configuration stored on this instance.
        # NOTE(review): ``self.pixel_columns`` is stored in ``__init__`` but is not
        # forwarded here -- confirm whether that omission is intentional.
        return GazeDataFrame(
            partition,
            experiment=self.experiment,
            auto_column_detect=self.auto_column_detect,
            trial_columns=self.trial_columns,
            time_column=self.time_column,
            time_unit=self.time_unit,
            position_columns=self.position_columns,
            velocity_columns=self.velocity_columns,
            acceleration_columns=self.acceleration_columns,
            distance_column=self.distance_column,
        )

    partitions = self.frame.partition_by(by=group_columns)
    return [_rebuild(partition) for partition in partitions]

def transform(
self,
transform_method: str | Callable[..., pl.Expr],
Expand Down

0 comments on commit 0856658

Please sign in to comment.