Skip to content

Commit

Permalink
fix: Infer correct number of components in gaze init (#521)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored Sep 12, 2023
1 parent 2e59d0c commit 2f7b54c
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 89 deletions.
105 changes: 64 additions & 41 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,55 +187,31 @@ def __init__(
if time_column is not None:
self.frame = self.frame.rename({time_column: 'time'})

n_components = None
# List of passed not-None column specifier lists.
# The list will be used for inferring n_components.
column_specifiers: list[list[str]] = []

if pixel_columns:
_check_component_columns(
frame=self.frame,
pixel_columns=pixel_columns,
)
self.nest(
input_columns=pixel_columns,
output_column='pixel',
)
n_components = len(pixel_columns)
_check_component_columns(self.frame, pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if position_columns:
_check_component_columns(
frame=self.frame,
position_columns=position_columns,
)

self.nest(
input_columns=position_columns,
output_column='position',
)
n_components = len(position_columns)
_check_component_columns(self.frame, position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if velocity_columns:
_check_component_columns(
frame=self.frame,
velocity_columns=velocity_columns,
)

self.nest(
input_columns=velocity_columns,
output_column='velocity',
)
n_components = len(velocity_columns)
_check_component_columns(self.frame, velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if acceleration_columns:
_check_component_columns(
frame=self.frame,
acceleration_columns=acceleration_columns,
)
_check_component_columns(self.frame, acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
column_specifiers.append(acceleration_columns)

self.nest(
input_columns=acceleration_columns,
output_column='acceleration',
)
n_components = len(acceleration_columns)

self.n_components = n_components
self.n_components = _infer_n_components(self.frame, column_specifiers)
self.experiment = experiment

def transform(
Expand Down Expand Up @@ -397,6 +373,8 @@ def nest(
output_column:
Name of the resulting tuple column.
"""
_check_component_columns(frame=self.frame, **{output_column: input_columns})

self.frame = self.frame.with_columns(
pl.concat_list([pl.col(component) for component in input_columns])
.alias(output_column),
Expand Down Expand Up @@ -526,3 +504,48 @@ def _check_n_components(n_components: Any) -> None:
"""Check that n_components is either 2, 4 or 6."""
if n_components not in {2, 4, 6}:
raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}')


def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None:
    """Infer number of components from DataFrame.

    Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of
    components by getting their list lengths, which must be equal for all else a ValueError is
    raised. Additionally, a list of list of column specifiers is checked for consistency.

    Null list lengths (reported for rows whose nested value is null) carry no information about
    the number of components and are ignored.

    Parameters
    ----------
    frame: pl.DataFrame
        DataFrame to check.
    column_specifiers:
        List of list of column specifiers.

    Returns
    -------
    int or None
        Number of components, or None if neither the frame nor the column specifiers provide
        any information to infer it from.

    Raises
    ------
    ValueError
        If number of components is not equal for all considered columns and rows.
    """
    all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
    considered_columns = [column for column in all_considered_columns if column in frame.columns]

    # Collect every distinct list length found in the nested columns. A null row yields a null
    # length; it must not be counted as a conflicting component count, so it is skipped.
    list_lengths = {
        list_length
        for column in considered_columns
        for list_length in frame.get_column(column).list.lengths().unique().to_list()
        if list_length is not None
    }

    # Explicitly passed column specifiers must agree with what is found in the frame.
    list_lengths.update(len(column_specifier_list) for column_specifier_list in column_specifiers)

    if len(list_lengths) > 1:
        raise ValueError(f'inconsistent number of components inferred: {list_lengths}')

    if not list_lengths:
        return None

    return next(iter(list_lengths))
Loading

0 comments on commit 2f7b54c

Please sign in to comment.