diff --git a/src/pymovements/gaze/gaze_dataframe.py b/src/pymovements/gaze/gaze_dataframe.py index 2e468e011..26cf9a453 100644 --- a/src/pymovements/gaze/gaze_dataframe.py +++ b/src/pymovements/gaze/gaze_dataframe.py @@ -194,26 +194,26 @@ def __init__( column_specifiers: list[list[str]] = [] if pixel_columns: - _check_component_columns(self.frame, pixel_columns=pixel_columns) + self._check_component_columns(pixel_columns=pixel_columns) self.nest(pixel_columns, output_column='pixel') column_specifiers.append(pixel_columns) if position_columns: - _check_component_columns(self.frame, position_columns=position_columns) + self._check_component_columns(position_columns=position_columns) self.nest(position_columns, output_column='position') column_specifiers.append(position_columns) if velocity_columns: - _check_component_columns(self.frame, velocity_columns=velocity_columns) + self._check_component_columns(velocity_columns=velocity_columns) self.nest(velocity_columns, output_column='velocity') column_specifiers.append(velocity_columns) if acceleration_columns: - _check_component_columns(self.frame, acceleration_columns=acceleration_columns) + self._check_component_columns(acceleration_columns=acceleration_columns) self.nest(acceleration_columns, output_column='acceleration') column_specifiers.append(acceleration_columns) - self.n_components = _infer_n_components(self.frame, column_specifiers) + self.n_components = self._infer_n_components(column_specifiers) self.experiment = experiment self.events = pm.EventDataFrame() @@ -266,7 +266,7 @@ def transform( kwargs['sampling_rate'] = self.experiment.sampling_rate if 'n_components' in method_kwargs and 'n_components' not in kwargs: - _check_n_components(self.n_components) + self._check_n_components() kwargs['n_components'] = self.n_components if self.trial_columns is None: @@ -380,46 +380,8 @@ def detect( if isinstance(method, str): method = pm.events.EventDetectionLibrary.get(method) - # Automatically infer eye to use for event detection. - _check_n_components(self.n_components) - if eye == 'auto': - if self.n_components == 6: - eye_components = 4, 5 - elif self.n_components == 4: - eye_components = 2, 3 - elif self.n_components == 2: - eye_components = 0, 1 - else: - raise AttributeError() - - method_args = inspect.getfullargspec(method).args - - if 'positions' in method_args: - positions = np.vstack( - [ - self.frame.get_column('position').list.get(eye_component) - for eye_component in eye_components - ], - ).transpose() - kwargs['positions'] = positions - - if 'velocities' in method_args: - velocities = np.vstack( - [ - self.frame.get_column('velocity').list.get(eye_component) - for eye_component in eye_components - ], - ).transpose() - kwargs['velocities'] = velocities - - if 'events' in method_args: - kwargs['events'] = self.events - - if 'timesteps' in method_args and 'time' in self.frame.columns: - timesteps = self.frame.get_column('time').to_numpy() - kwargs['timesteps'] = timesteps - - new_events = method(**kwargs) + method_kwargs = self._fill_event_detection_kwargs(method, eye, **kwargs) + new_events = method(**method_kwargs) self.events.frame = pl.concat( [self.events.frame, new_events.frame], @@ -452,7 +414,7 @@ def nest( output_column: Name of the resulting tuple column. """ - _check_component_columns(frame=self.frame, **{output_column: input_columns}) + self._check_component_columns(**{output_column: input_columns}) self.frame = self.frame.with_columns( pl.concat_list([pl.col(component) for component in input_columns]) @@ -492,7 +454,7 @@ def unnest( output_columns=output_columns, output_suffixes=output_suffixes, ) - _check_n_components(self.n_components) + self._check_n_components() col_names = output_columns if output_columns is not None else [] @@ -542,89 +504,154 @@ def _check_experiment(self) -> None: if self.experiment is None: raise AttributeError('experiment must not be None for this method to work') - -def _check_component_columns( - frame: pl.DataFrame, - **kwargs: list[str], -) -> None: - """Check if component columns are in valid format.""" - for component_type, columns in kwargs.items(): - if not isinstance(columns, list): - raise TypeError( - f'{component_type} must be of type list, but is of type {type(columns).__name__}', + def _check_n_components(self) -> None: + """Check that n_components is either 2, 4 or 6.""" + if self.n_components not in {2, 4, 6}: + raise AttributeError( + f'n_components must be either 2, 4 or 6 but is {self.n_components}', ) - for column in columns: - if not isinstance(column, str): + def _check_component_columns(self, **kwargs: list[str]) -> None: + """Check if component columns are in valid format.""" + for component_type, columns in kwargs.items(): + if not isinstance(columns, list): raise TypeError( - f'all elements in {component_type} must be of type str, ' - f'but one of the elements is of type {type(column).__name__}', + f'{component_type} must be of type list, ' + f'but is of type {type(columns).__name__}', ) - if len(columns) not in [2, 4, 6]: - raise ValueError( - f'{component_type} must contain either 2, 4 or 6 columns, but has {len(columns)}', - ) + for column in columns: + if not isinstance(column, str): + raise TypeError( + f'all elements in {component_type} must be of type str, ' + f'but one of the elements is of type {type(column).__name__}', + ) + + if len(columns) not in [2, 4, 6]: + raise ValueError( + f'{component_type} must contain either 2, 4 or 6 columns, ' + f'but has {len(columns)}', + ) - for column in columns: - if column not in frame.columns: - raise pl.exceptions.ColumnNotFoundError( - f'column {column} from {component_type} is not available in dataframe', + for column in columns: + if column not in self.frame.columns: + raise pl.exceptions.ColumnNotFoundError( + f'column {column} from {component_type} is not available in dataframe', + ) + + if len(set(self.frame[columns].dtypes)) != 1: + types_list = sorted([str(t) for t in set(self.frame[columns].dtypes)]) + raise ValueError( + f'all columns in {component_type} must be of same type, ' + f'but types are {types_list}', ) - if len(set(frame[columns].dtypes)) != 1: - types_list = sorted([str(t) for t in set(frame[columns].dtypes)]) - raise ValueError( - f'all columns in {component_type} must be of same type, but types are {types_list}', - ) + def _infer_n_components(self, column_specifiers: list[list[str]]) -> int | None: + """Infer number of components from DataFrame. + + Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number + of components by getting their list lenghts, which must be equal for all else a ValueError + is raised. Additionally, a list of list of column specifiers is checked for consistency. + + Parameters + ---------- + frame: pl.DataFrame + DataFrame to check. + column_specifiers: + List of list of column specifiers. + Returns + ------- + int or None + Number of components + + Raises + ------ + ValueError + If number of components is not equal for all considered columns and rows. + """ + all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration'] + considered_columns = [ + column for column in all_considered_columns if column in self.frame.columns + ] -def _check_n_components(n_components: Any) -> None: - """Check that n_components is either 2, 4 or 6.""" - if n_components not in {2, 4, 6}: - raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}') + list_lengths = { + list_length + for column in considered_columns + for list_length in self.frame.get_column(column).list.lengths().unique().to_list() + } + for column_specifier_list in column_specifiers: + list_lengths.add(len(column_specifier_list)) -def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None: - """Infer number of components from DataFrame. + if len(list_lengths) > 1: + raise ValueError(f'inconsistent number of components inferred: {list_lengths}') - Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of - components by getting their list lenghts, which must be equal for all else a ValueError is - raised. Additionally, a list of list of column specifiers is checked for consistency. + if len(list_lengths) == 0: + return None - Parameters - ---------- - frame: pl.DataFrame - DataFrame to check. - column_specifiers: - List of list of column specifiers. + return next(iter(list_lengths)) - Returns - ------- - int or None - Number of components + def _infer_eye_components(self, eye: str) -> tuple[int, int]: + self._check_n_components() - Raises - ------ - ValueError - If number of components is not equal for all considered columns and rows. - """ - all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration'] - considered_columns = [column for column in all_considered_columns if column in frame.columns] + if eye == 'auto': + # Order of inference: cyclops, right, left + if self.n_components == 6: + eye_components = 4, 5 + elif self.n_components == 4: + eye_components = 2, 3 + else: # We already checked validity, must be 2. + eye_components = 0, 1 + elif eye == 'left': + if self.n_components < 4: # Left only makes sense if there are at least two eyes. + raise AttributeError() + eye_components = 0, 1 + elif eye == 'right': + if self.n_components < 4: # Right only makes sense if there are at least two eyes. + raise AttributeError() + eye_components = 2, 3 + elif eye == 'cyclops': + if self.n_components < 6: + raise AttributeError() + eye_components = 4, 5 + elif eye == 'mono': + eye_components = 0, 1 + else: + raise ValueError() + + return eye_components - list_lengths = { - list_length - for column in considered_columns - for list_length in frame.get_column(column).list.lengths().unique().to_list() - } + def _fill_event_detection_kwargs( + self, + method: Callable[..., pm.EventDataFrame], + eye: str, + **kwargs, + ): + # Automatically infer eye to use for event detection. + eye_components = self._infer_eye_components(eye) + method_args = inspect.getfullargspec(method).args + + if 'positions' in method_args: + kwargs['positions'] = np.vstack( + [ + self.frame.get_column('position').list.get(eye_component) + for eye_component in eye_components + ], + ).transpose() - for column_specifier_list in column_specifiers: - list_lengths.add(len(column_specifier_list)) + if 'velocities' in method_args: + kwargs['velocities'] = np.vstack( + [ + self.frame.get_column('velocity').list.get(eye_component) + for eye_component in eye_components + ], + ).transpose() - if len(list_lengths) > 1: - raise ValueError(f'inconsistent number of components inferred: {list_lengths}') + if 'events' in method_args: + kwargs['events'] = self.events - if len(list_lengths) == 0: - return None + if 'timesteps' in method_args and 'time' in self.frame.columns: + kwargs['timesteps'] = self.frame.get_column('time').to_numpy() - return next(iter(list_lengths)) + return kwargs diff --git a/tests/events/detection/ivt_test.py b/tests/events/detection/ivt_test.py index 0aaea5173..e5377db55 100644 --- a/tests/events/detection/ivt_test.py +++ b/tests/events/detection/ivt_test.py @@ -219,8 +219,6 @@ def test_ivt_detects_fixations(kwargs, expected): kwargs['positions'], sampling_rate=10, method='preceding', ) - print(velocities) - assert False # Just use positions argument for velocity calculation kwargs.pop('positions')