Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako committed Sep 13, 2023
1 parent 5ee2320 commit b8f6db2
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 115 deletions.
253 changes: 140 additions & 113 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,26 +194,26 @@ def __init__(
column_specifiers: list[list[str]] = []

if pixel_columns:
_check_component_columns(self.frame, pixel_columns=pixel_columns)
self._check_component_columns(pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if position_columns:
_check_component_columns(self.frame, position_columns=position_columns)
self._check_component_columns(position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if velocity_columns:
_check_component_columns(self.frame, velocity_columns=velocity_columns)
self._check_component_columns(velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if acceleration_columns:
_check_component_columns(self.frame, acceleration_columns=acceleration_columns)
self._check_component_columns(acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
column_specifiers.append(acceleration_columns)

self.n_components = _infer_n_components(self.frame, column_specifiers)
self.n_components = self._infer_n_components(column_specifiers)
self.experiment = experiment
self.events = pm.EventDataFrame()

Expand Down Expand Up @@ -266,7 +266,7 @@ def transform(
kwargs['sampling_rate'] = self.experiment.sampling_rate

if 'n_components' in method_kwargs and 'n_components' not in kwargs:
_check_n_components(self.n_components)
self._check_n_components()
kwargs['n_components'] = self.n_components

if self.trial_columns is None:
Expand Down Expand Up @@ -380,46 +380,8 @@ def detect(
if isinstance(method, str):
method = pm.events.EventDetectionLibrary.get(method)

# Automatically infer eye to use for event detection.
_check_n_components(self.n_components)
if eye == 'auto':
if self.n_components == 6:
eye_components = 4, 5
elif self.n_components == 4:
eye_components = 2, 3
elif self.n_components == 2:
eye_components = 0, 1
else:
raise AttributeError()

method_args = inspect.getfullargspec(method).args

if 'positions' in method_args:
positions = np.vstack(
[
self.frame.get_column('position').list.get(eye_component)
for eye_component in eye_components
],
).transpose()
kwargs['positions'] = positions

if 'velocities' in method_args:
velocities = np.vstack(
[
self.frame.get_column('velocity').list.get(eye_component)
for eye_component in eye_components
],
).transpose()
kwargs['velocities'] = velocities

if 'events' in method_args:
kwargs['events'] = self.events

if 'timesteps' in method_args and 'time' in self.frame.columns:
timesteps = self.frame.get_column('time').to_numpy()
kwargs['timesteps'] = timesteps

new_events = method(**kwargs)
method_kwargs = self._fill_event_detection_kwargs(method, eye, **kwargs)
new_events = method(**method_kwargs)

self.events.frame = pl.concat(
[self.events.frame, new_events.frame],
Expand Down Expand Up @@ -452,7 +414,7 @@ def nest(
output_column:
Name of the resulting tuple column.
"""
_check_component_columns(frame=self.frame, **{output_column: input_columns})
self._check_component_columns(**{output_column: input_columns})

self.frame = self.frame.with_columns(
pl.concat_list([pl.col(component) for component in input_columns])
Expand Down Expand Up @@ -492,7 +454,7 @@ def unnest(
output_columns=output_columns,
output_suffixes=output_suffixes,
)
_check_n_components(self.n_components)
self._check_n_components()

col_names = output_columns if output_columns is not None else []

Expand Down Expand Up @@ -542,89 +504,154 @@ def _check_experiment(self) -> None:
if self.experiment is None:
raise AttributeError('experiment must not be None for this method to work')


def _check_component_columns(
frame: pl.DataFrame,
**kwargs: list[str],
) -> None:
"""Check if component columns are in valid format."""
for component_type, columns in kwargs.items():
if not isinstance(columns, list):
raise TypeError(
f'{component_type} must be of type list, but is of type {type(columns).__name__}',
def _check_n_components(self) -> None:
"""Check that n_components is either 2, 4 or 6."""
if self.n_components not in {2, 4, 6}:
raise AttributeError(
f'n_components must be either 2, 4 or 6 but is {self.n_components}',
)

for column in columns:
if not isinstance(column, str):
def _check_component_columns(self, **kwargs: list[str]) -> None:
"""Check if component columns are in valid format."""
for component_type, columns in kwargs.items():
if not isinstance(columns, list):
raise TypeError(
f'all elements in {component_type} must be of type str, '
f'but one of the elements is of type {type(column).__name__}',
f'{component_type} must be of type list, '
f'but is of type {type(columns).__name__}',
)

if len(columns) not in [2, 4, 6]:
raise ValueError(
f'{component_type} must contain either 2, 4 or 6 columns, but has {len(columns)}',
)
for column in columns:
if not isinstance(column, str):
raise TypeError(
f'all elements in {component_type} must be of type str, '
f'but one of the elements is of type {type(column).__name__}',
)

if len(columns) not in [2, 4, 6]:
raise ValueError(
f'{component_type} must contain either 2, 4 or 6 columns, '
f'but has {len(columns)}',
)

for column in columns:
if column not in frame.columns:
raise pl.exceptions.ColumnNotFoundError(
f'column {column} from {component_type} is not available in dataframe',
for column in columns:
if column not in self.frame.columns:
raise pl.exceptions.ColumnNotFoundError(
f'column {column} from {component_type} is not available in dataframe',
)

if len(set(self.frame[columns].dtypes)) != 1:
types_list = sorted([str(t) for t in set(self.frame[columns].dtypes)])
raise ValueError(
f'all columns in {component_type} must be of same type, '
f'but types are {types_list}',
)

if len(set(frame[columns].dtypes)) != 1:
types_list = sorted([str(t) for t in set(frame[columns].dtypes)])
raise ValueError(
f'all columns in {component_type} must be of same type, but types are {types_list}',
)
def _infer_n_components(self, column_specifiers: list[list[str]]) -> int | None:
"""Infer number of components from DataFrame.
Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number
of components by getting their list lenghts, which must be equal for all else a ValueError
is raised. Additionally, a list of list of column specifiers is checked for consistency.
Parameters
----------
frame: pl.DataFrame
DataFrame to check.
column_specifiers:
List of list of column specifiers.
Returns
-------
int or None
Number of components
Raises
------
ValueError
If number of components is not equal for all considered columns and rows.
"""
all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
considered_columns = [
column for column in all_considered_columns if column in self.frame.columns
]

def _check_n_components(n_components: Any) -> None:
"""Check that n_components is either 2, 4 or 6."""
if n_components not in {2, 4, 6}:
raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}')
list_lengths = {
list_length
for column in considered_columns
for list_length in self.frame.get_column(column).list.lengths().unique().to_list()
}

for column_specifier_list in column_specifiers:
list_lengths.add(len(column_specifier_list))

def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None:
"""Infer number of components from DataFrame.
if len(list_lengths) > 1:
raise ValueError(f'inconsistent number of components inferred: {list_lengths}')

Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of
components by getting their list lenghts, which must be equal for all else a ValueError is
raised. Additionally, a list of list of column specifiers is checked for consistency.
if len(list_lengths) == 0:
return None

Parameters
----------
frame: pl.DataFrame
DataFrame to check.
column_specifiers:
List of list of column specifiers.
return next(iter(list_lengths))

Returns
-------
int or None
Number of components
def _infer_eye_components(self, eye: str) -> tuple[int, int]:
self._check_n_components()

Raises
------
ValueError
If number of components is not equal for all considered columns and rows.
"""
all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
considered_columns = [column for column in all_considered_columns if column in frame.columns]
if eye == 'auto':
# Order of inference: cyclops, right, left
if self.n_components == 6:
eye_components = 4, 5
elif self.n_components == 4:
eye_components = 2, 3
else: # We already checked validity, must be 2.
eye_components = 0, 1
elif eye == 'left':
if self.n_components < 4: # Left only makes sense if there are at least two eyes.
raise AttributeError()
eye_components = 0, 1
elif eye == 'right':
if self.n_components < 4: # Right only makes sense if there are at least two eyes.
raise AttributeError()
eye_components = 2, 3
elif eye == 'cyclops':
if self.n_components < 6:
raise AttributeError()
eye_components = 4, 5
elif eye == 'mono':
eye_components = 0, 1
else:
raise ValueError()

return eye_components

list_lengths = {
list_length
for column in considered_columns
for list_length in frame.get_column(column).list.lengths().unique().to_list()
}
def _fill_event_detection_kwargs(
self,
method: Callable[..., pm.EventDataFrame],
eye: str,
**kwargs,
):
# Automatically infer eye to use for event detection.
eye_components = self._infer_eye_components(eye)
method_args = inspect.getfullargspec(method).args

if 'positions' in method_args:
kwargs['positions'] = np.vstack(
[
self.frame.get_column('position').list.get(eye_component)
for eye_component in eye_components
],
).transpose()

for column_specifier_list in column_specifiers:
list_lengths.add(len(column_specifier_list))
if 'velocities' in method_args:
kwargs['velocities'] = np.vstack(
[
self.frame.get_column('velocity').list.get(eye_component)
for eye_component in eye_components
],
).transpose()

if len(list_lengths) > 1:
raise ValueError(f'inconsistent number of components inferred: {list_lengths}')
if 'events' in method_args:
kwargs['events'] = self.events

if len(list_lengths) == 0:
return None
if 'timesteps' in method_args and 'time' in self.frame.columns:
kwargs['timesteps'] = self.frame.get_column('time').to_numpy()

return next(iter(list_lengths))
return kwargs
2 changes: 0 additions & 2 deletions tests/events/detection/ivt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ def test_ivt_detects_fixations(kwargs, expected):
kwargs['positions'], sampling_rate=10, method='preceding',
)

print(velocities)
assert False
# Just use positions argument for velocity calculation
kwargs.pop('positions')

Expand Down

0 comments on commit b8f6db2

Please sign in to comment.