Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako committed Sep 14, 2023
1 parent d4ac1fd commit 61bd27c
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/pymovements/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def detect(
self._check_gaze_dataframe()

if not self.events:
self.events = [None for _ in self.gaze]
self.events = [gaze.events for gaze in self.gaze]

disable_progressbar = not verbose
for file_id, (gaze, fileinfo_row) in tqdm(
Expand Down
263 changes: 150 additions & 113 deletions src/pymovements/gaze/gaze_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,26 @@ def __init__(
column_specifiers: list[list[str]] = []

if pixel_columns:
_check_component_columns(self.frame, pixel_columns=pixel_columns)
self._check_component_columns(pixel_columns=pixel_columns)
self.nest(pixel_columns, output_column='pixel')
column_specifiers.append(pixel_columns)

if position_columns:
_check_component_columns(self.frame, position_columns=position_columns)
self._check_component_columns(position_columns=position_columns)
self.nest(position_columns, output_column='position')
column_specifiers.append(position_columns)

if velocity_columns:
_check_component_columns(self.frame, velocity_columns=velocity_columns)
self._check_component_columns(velocity_columns=velocity_columns)
self.nest(velocity_columns, output_column='velocity')
column_specifiers.append(velocity_columns)

if acceleration_columns:
_check_component_columns(self.frame, acceleration_columns=acceleration_columns)
self._check_component_columns(acceleration_columns=acceleration_columns)
self.nest(acceleration_columns, output_column='acceleration')
column_specifiers.append(acceleration_columns)

self.n_components = _infer_n_components(self.frame, column_specifiers)
self.n_components = self._infer_n_components(column_specifiers)
self.experiment = experiment

if events is None:
Expand Down Expand Up @@ -273,7 +273,7 @@ def transform(
kwargs['sampling_rate'] = self.experiment.sampling_rate

if 'n_components' in method_kwargs and 'n_components' not in kwargs:
_check_n_components(self.n_components)
self._check_n_components()
kwargs['n_components'] = self.n_components

if transform_method.__name__ in {'pos2vel', 'pos2acc'}:
Expand Down Expand Up @@ -388,7 +388,7 @@ def detect(
self,
method: Callable[..., pm.EventDataFrame] | str,
*,
eye: str | None = 'auto',
eye: str = 'auto',
clear: bool = False,
**kwargs: Any,
) -> None:
Expand All @@ -414,46 +414,8 @@ def detect(
if isinstance(method, str):
method = pm.events.EventDetectionLibrary.get(method)

# Automatically infer eye to use for event detection.
_check_n_components(self.n_components)
if eye == 'auto':
if self.n_components == 6:
eye_components = 4, 5
elif self.n_components == 4:
eye_components = 2, 3
elif self.n_components == 2:
eye_components = 0, 1
else:
raise AttributeError()

method_args = inspect.getfullargspec(method).args

if 'positions' in method_args:
positions = np.vstack(
[
self.frame.get_column('position').list.get(eye_component)
for eye_component in eye_components
],
).transpose()
kwargs['positions'] = positions

if 'velocities' in method_args:
velocities = np.vstack(
[
self.frame.get_column('velocity').list.get(eye_component)
for eye_component in eye_components
],
).transpose()
kwargs['velocities'] = velocities

if 'events' in method_args:
kwargs['events'] = self.events

if 'timesteps' in method_args and 'time' in self.frame.columns:
timesteps = self.frame.get_column('time').to_numpy()
kwargs['timesteps'] = timesteps

new_events = method(**kwargs)
method_kwargs = self._fill_event_detection_kwargs(method, eye, **kwargs)
new_events = method(**method_kwargs)

self.events.frame = pl.concat(
[self.events.frame, new_events.frame],
Expand Down Expand Up @@ -486,6 +448,8 @@ def nest(
output_column:
Name of the resulting tuple column.
"""
self._check_component_columns(**{output_column: input_columns})

self.frame = self.frame.with_columns(
pl.concat_list([pl.col(component) for component in input_columns])
.alias(output_column),
Expand Down Expand Up @@ -524,7 +488,7 @@ def unnest(
output_columns=output_columns,
output_suffixes=output_suffixes,
)
_check_n_components(self.n_components)
self._check_n_components()

col_names = output_columns if output_columns is not None else []

Expand Down Expand Up @@ -574,89 +538,162 @@ def _check_experiment(self) -> None:
if self.experiment is None:
raise AttributeError('experiment must not be None for this method to work')


def _check_component_columns(
frame: pl.DataFrame,
**kwargs: list[str],
) -> None:
"""Check if component columns are in valid format."""
for component_type, columns in kwargs.items():
if not isinstance(columns, list):
raise TypeError(
f'{component_type} must be of type list, but is of type {type(columns).__name__}',
def _check_n_components(self) -> None:
"""Check that n_components is either 2, 4 or 6."""
if self.n_components not in {2, 4, 6}:
raise AttributeError(
f'n_components must be either 2, 4 or 6 but is {self.n_components}',
)

for column in columns:
if not isinstance(column, str):
def _check_component_columns(self, **kwargs: list[str]) -> None:
"""Check if component columns are in valid format."""
for component_type, columns in kwargs.items():
if not isinstance(columns, list):
raise TypeError(
f'all elements in {component_type} must be of type str, '
f'but one of the elements is of type {type(column).__name__}',
f'{component_type} must be of type list, '
f'but is of type {type(columns).__name__}',
)

if len(columns) not in [2, 4, 6]:
raise ValueError(
f'{component_type} must contain either 2, 4 or 6 columns, but has {len(columns)}',
)
for column in columns:
if not isinstance(column, str):
raise TypeError(
f'all elements in {component_type} must be of type str, '
f'but one of the elements is of type {type(column).__name__}',
)

for column in columns:
if column not in frame.columns:
raise pl.exceptions.ColumnNotFoundError(
f'column {column} from {component_type} is not available in dataframe',
if len(columns) not in [2, 4, 6]:
raise ValueError(
f'{component_type} must contain either 2, 4 or 6 columns, '
f'but has {len(columns)}',
)

if len(set(frame[columns].dtypes)) != 1:
types_list = sorted([str(t) for t in set(frame[columns].dtypes)])
raise ValueError(
f'all columns in {component_type} must be of same type, but types are {types_list}',
)
for column in columns:
if column not in self.frame.columns:
raise pl.exceptions.ColumnNotFoundError(
f'column {column} from {component_type} is not available in dataframe',
)

if len(set(self.frame[columns].dtypes)) != 1:
types_list = sorted([str(t) for t in set(self.frame[columns].dtypes)])
raise ValueError(
f'all columns in {component_type} must be of same type, '
f'but types are {types_list}',
)

def _infer_n_components(self, column_specifiers: list[list[str]]) -> int | None:
"""Infer number of components from DataFrame.
Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number
of components by getting their list lenghts, which must be equal for all else a ValueError
is raised. Additionally, a list of list of column specifiers is checked for consistency.
Parameters
----------
column_specifiers:
List of list of column specifiers.
Returns
-------
int or None
Number of components
def _check_n_components(n_components: Any) -> None:
"""Check that n_components is either 2, 4 or 6."""
if n_components not in {2, 4, 6}:
raise AttributeError(f'n_components must be either 2, 4 or 6 but is {n_components}')
Raises
------
ValueError
If number of components is not equal for all considered columns and rows.
"""
all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
considered_columns = [
column for column in all_considered_columns if column in self.frame.columns
]

list_lengths = {
list_length
for column in considered_columns
for list_length in self.frame.get_column(column).list.lengths().unique().to_list()
}

def _infer_n_components(frame: pl.DataFrame, column_specifiers: list[list[str]]) -> int | None:
"""Infer number of components from DataFrame.
for column_specifier_list in column_specifiers:
list_lengths.add(len(column_specifier_list))

Method checks nested columns `pixel`, `position`, `velocity` and `acceleration` for number of
components by getting their list lenghts, which must be equal for all else a ValueError is
raised. Additionally, a list of list of column specifiers is checked for consistency.
if len(list_lengths) > 1:
raise ValueError(f'inconsistent number of components inferred: {list_lengths}')

Parameters
----------
frame: pl.DataFrame
DataFrame to check.
column_specifiers:
List of list of column specifiers.
if len(list_lengths) == 0:
return None

Returns
-------
int or None
Number of components
return next(iter(list_lengths))

Raises
------
ValueError
If number of components is not equal for all considered columns and rows.
"""
all_considered_columns = ['pixel', 'position', 'velocity', 'acceleration']
considered_columns = [column for column in all_considered_columns if column in frame.columns]
def _infer_eye_components(self, eye: str) -> tuple[int, int]:
"""Infer eye components from eye string.
list_lengths = {
list_length
for column in considered_columns
for list_length in frame.get_column(column).list.lengths().unique().to_list()
}
Parameters
----------
eye: str
String specificer for inferring eye components. Supported values are: auto, mono, left
right, cyclops. Default: auto.
"""
self._check_n_components()

for column_specifier_list in column_specifiers:
list_lengths.add(len(column_specifier_list))
if eye == 'auto':
# Order of inference: cyclops, right, left
if self.n_components == 6:
eye_components = 4, 5
elif self.n_components == 4:
eye_components = 2, 3
else: # We already checked validity, must be 2.
eye_components = 0, 1
elif eye == 'left':
if isinstance(self.n_components, int) and self.n_components < 4:
# Left only makes sense if there are at least two eyes.
raise AttributeError()
eye_components = 0, 1
elif eye == 'right':
if isinstance(self.n_components, int) and self.n_components < 4:
# Right only makes sense if there are at least two eyes.
raise AttributeError()
eye_components = 2, 3
elif eye == 'cyclops':
if isinstance(self.n_components, int) and self.n_components < 6:
raise AttributeError()
eye_components = 4, 5
elif eye == 'mono':
eye_components = 0, 1
else:
raise ValueError()

if len(list_lengths) > 1:
raise ValueError(f'inconsistent number of components inferred: {list_lengths}')
return eye_components

if len(list_lengths) == 0:
return None
def _fill_event_detection_kwargs(
self,
method: Callable[..., pm.EventDataFrame],
eye: str,
**kwargs: Any,
) -> dict[str, Any]:
# Automatically infer eye to use for event detection.
eye_components = self._infer_eye_components(eye)
method_args = inspect.getfullargspec(method).args

if 'positions' in method_args:
kwargs['positions'] = np.vstack(
[
self.frame.get_column('position').list.get(eye_component)
for eye_component in eye_components
],
).transpose()

if 'velocities' in method_args:
kwargs['velocities'] = np.vstack(
[
self.frame.get_column('velocity').list.get(eye_component)
for eye_component in eye_components
],
).transpose()

if 'events' in method_args:
kwargs['events'] = self.events

if 'timesteps' in method_args and 'time' in self.frame.columns:
kwargs['timesteps'] = self.frame.get_column('time').to_numpy()

return next(iter(list_lengths))
return kwargs
2 changes: 0 additions & 2 deletions tests/events/detection/ivt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ def test_ivt_detects_fixations(kwargs, expected):
kwargs['positions'], sampling_rate=10, method='preceding',
)

print(velocities)
assert False
# Just use positions argument for velocity calculation
kwargs.pop('positions')

Expand Down

0 comments on commit 61bd27c

Please sign in to comment.