Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Infer correct number of components in gaze init #521

Merged
merged 5 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,55 +187,31 @@ def __init__(
if time_column is not None:
self.frame = self.frame.rename({time_column: 'time'})

n_components = None
# List of passed not-None column specifier lists.
# The list will be used for inferring n_components.
column_specifiers: list[list[str]] = []

if pixel_columns:
_check_component_columns(
frame=self.frame,
pixel_columns=pixel_columns,
)
self.nest(
input_columns=pixel_columns,
output_column='pixel',
)
n_components = len(pixel_columns)
_check_component_columns(self.frame, pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if position_columns:
_check_component_columns(
frame=self.frame,
position_columns=position_columns,
)

self.nest(
input_columns=position_columns,
output_column='position',
)
n_components = len(position_columns)
_check_component_columns(self.frame, position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if velocity_columns:
_check_component_columns(
frame=self.frame,
velocity_columns=velocity_columns,
)

self.nest(
input_columns=velocity_columns,
output_column='velocity',
)
n_components = len(velocity_columns)
_check_component_columns(self.frame, velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if acceleration_columns:
_check_component_columns(
frame=self.frame,
acceleration_columns=acceleration_columns,
)
_check_component_columns(self.frame, acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
column_specifiers.append(acceleration_columns)

self.nest(
input_columns=acceleration_columns,
output_column='acceleration',
)
n_components = len(acceleration_columns)

self.n_components = n_components
self.n_components = _infer_n_components(self.frame, column_specifiers)
self.experiment = experiment

def transform(
Expand Down Expand Up @@ -397,6 +373,8 @@ def nest(
output_column:
Name of the resulting tuple column.
"""
_check_component_columns(frame=self.frame, **{output_column: input_columns})

self.frame = self.frame.with_columns(
pl.concat_list([pl.col(component) for component in input_columns])
.alias(output_column),
Expand Down Expand Up @@ -526,3 +504,48 @@ def _check_n_components(n_components: Any) -> None:
"""Check that n_components is either 2, 4 or 6."""
if n_components not in {2, 4, 6}:
raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}')


def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None:
"""Infer number of components from DataFrame.

Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of
components by getting their list lenghts, which must be equal for all else a ValueError is
raised. Additionally, a list of list of column specifiers is checked for consistency.

Parameters
----------
frame: pl.DataFrame
DataFrame to check.
column_specifiers:
List of list of column specifiers.

Returns
-------
int or None
Number of components

Raises
------
ValueError
If number of components is not equal for all considered columns and rows.
"""
all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
dkrako marked this conversation as resolved.
Show resolved Hide resolved
considered_columns = [column for column in all_considered_columns if column in frame.columns]

list_lengths = {
list_length
for column in considered_columns
for list_length in frame.get_column(column).list.lengths().unique().to_list()
}

for column_specifier_list in column_specifiers:
list_lengths.add(len(column_specifier_list))

if len(list_lengths) > 1:
raise ValueError(f'inconsistent number of components inferred: {list_lengths}')

if len(list_lengths) == 0:
return None

return next(iter(list_lengths))
Loading