Skip to content

Commit

Permalink
refactor(events)!: Refactor gaze event processing for input tuples (#422
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dkrako authored May 26, 2023
1 parent d80420c commit d9a94ef
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 254 deletions.
51 changes: 25 additions & 26 deletions src/pymovements/events/event_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,31 +164,36 @@ def process(
for property_name in self.event_properties
}

# We need to create a new column here, which is a list of position tuples.
# For the intermediate time before the tuple will be the default format for
# positions, we create this column here and drop this column afterwards.
position_columns = tuple(gaze.position_columns[:2])
if position_columns:
position_component_expressions = [pl.col(component) for component in position_columns]
gaze.frame = gaze.frame.with_columns(
pl.concat_list(position_component_expressions)
.alias('position'),
)
velocity_columns = tuple(gaze.velocity_columns[:2])
if velocity_columns:
velocity_component_expressions = [pl.col(component) for component in velocity_columns]
gaze.frame = gaze.frame.with_columns(
pl.concat_list(velocity_component_expressions)
.alias('velocity'),
)

property_kwargs: dict[str, dict[str, Any]] = {
property_name: {} for property_name in property_expressions.keys()
}
for property_name, property_expression in property_expressions.items():
property_args = inspect.getfullargspec(property_expression).args
if 'velocity_columns' in property_args:
velocity_columns = tuple(gaze.velocity_columns[:2])
property_kwargs[property_name]['velocity_columns'] = velocity_columns

if 'position_columns' in property_args:
position_columns = tuple(gaze.position_columns[:2])
property_kwargs[property_name]['position_columns'] = position_columns
property_args = inspect.getfullargspec(property_expression).kwonlyargs

if 'position_column' in property_args:
# We need to create a new column here, which is a list of position tuples.
# For the intermediate time before the tuple will be the default format for
# positions, we create this column here and drop this column afterwards.
position_columns = tuple(gaze.position_columns[:2])
component_expressions = [pl.col(component) for component in position_columns]
gaze.frame = gaze.frame.with_columns(
pl.concat_list(component_expressions)
.alias('position'),
)
property_kwargs[property_name]['position_column'] = 'position'

if 'velocity_column' in property_args:
property_kwargs[property_name]['velocity_column'] = 'velocity'

result = (
gaze.frame.join(events.frame, on=identifiers)
.filter(pl.col('time').is_between(pl.col('onset'), pl.col('offset')))
Expand All @@ -198,20 +203,14 @@ def process(
property_expression(**property_kwargs[property_name])
.alias(property_name)
for property_name, property_expression in property_expressions.items()
if 'position_column' not in property_kwargs[property_name]
] + [
property_expression(**property_kwargs[property_name])
.alias(property_name)
.first() # Not sure why this is needed, an outer list is being created somehow.
for property_name, property_expression in property_expressions.items()
if 'position_column' in property_kwargs[property_name]

],
)
)

# If we created the position tuple column we drop it again.
# If we created the position and velocity tuple columns, we drop them again.
if 'position' in gaze.frame.columns:
gaze.frame.drop_in_place('position')
if 'velocity' in gaze.frame.columns:
gaze.frame.drop_in_place('velocity')

return result
167 changes: 80 additions & 87 deletions src/pymovements/events/event_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,115 +35,137 @@ def register_event_property(function: Callable) -> Callable:


@register_event_property
def duration() -> pl.Expr:
"""Duration of an event.
The duration is defined as the difference between offset time and onset time.
"""
return pl.col('offset') - pl.col('onset')


@register_event_property
def peak_velocity(velocity_columns: tuple[str, str] = ('x_vel', 'y_vel')) -> pl.Expr:
"""Peak velocity of an event.
def amplitude(
*,
position_column: str = 'position',
n_components: int = 2,
) -> pl.Expr:
"""Amplitude of an event.
Parameters
----------
velocity_columns
The column names of the pitch and yaw velocity components.
position_column
The column name of the position tuples.
n_components:
Number of positional components. Usually these are the two components yaw and pitch.
Raises
------
TypeError
If velocity_columns not of type tuple, velocity_columns not of length 2, or elements of
velocity_columns not of type str.
ValueError
If number of components is not 2.
"""
_check_velocity_columns(velocity_columns)
_check_has_two_componenents(n_components)

x_velocity = pl.col(velocity_columns[0])
y_velocity = pl.col(velocity_columns[1])
x_position = pl.col(position_column).arr.get(0)
y_position = pl.col(position_column).arr.get(1)

return (x_velocity.pow(2) + y_velocity.pow(2)).sqrt().max()
return (
(x_position.max() - x_position.min()).pow(2)
+ (y_position.max() - y_position.min()).pow(2)
).sqrt()


@register_event_property
def dispersion(position_columns: tuple[str, str] = ('x_pos', 'y_pos')) -> pl.Expr:
def dispersion(
*,
position_column: str = 'position',
n_components: int = 2,
) -> pl.Expr:
"""Dispersion of an event.
Parameters
----------
position_columns
The column names of the pitch and yaw position components.
position_column
The column name of the position tuples.
n_components:
Number of positional components. Usually these are the two components yaw and pitch.
Raises
------
TypeError
If position_columns not of type tuple, position_columns not of length 2, or elements of
position_columns not of type str.
ValueError
If number of components is not 2.
"""
_check_position_columns(position_columns)
_check_has_two_componenents(n_components)

x_position = pl.col(position_columns[0])
y_position = pl.col(position_columns[1])
x_position = pl.col(position_column).arr.get(0)
y_position = pl.col(position_column).arr.get(1)

return x_position.max() - x_position.min() + y_position.max() - y_position.min()


@register_event_property
def amplitude(position_columns: tuple[str, str] = ('x_pos', 'y_pos')) -> pl.Expr:
"""Amplitude of an event.
def disposition(
*,
position_column: str = 'position',
n_components: int = 2,
) -> pl.Expr:
"""Disposition of an event.
Parameters
----------
position_columns
The column names of the pitch and yaw position components.
position_column
The column name of the position tuples.
n_components:
Number of positional components. Usually these are the two components yaw and pitch.
Raises
------
TypeError
If position_columns not of type tuple, position_columns not of length 2, or elements of
position_columns not of type str.
"""
_check_position_columns(position_columns)
_check_has_two_componenents(n_components)

x_position = pl.col(position_columns[0])
y_position = pl.col(position_columns[1])
x_position = pl.col(position_column).arr.get(0)
y_position = pl.col(position_column).arr.get(1)

return (
(x_position.max() - x_position.min()).pow(2)
+ (y_position.max() - y_position.min()).pow(2)
(x_position.head(n=1) - x_position.reverse().head(n=1)).pow(2)
+ (y_position.head(n=1) - y_position.reverse().head(n=1)).pow(2)
).sqrt()


@register_event_property
def disposition(position_columns: tuple[str, str] = ('x_pos', 'y_pos')) -> pl.Expr:
"""Disposition of an event.
def duration() -> pl.Expr:
    """Duration of an event.

    The duration is defined as the difference between offset time and onset time.

    Returns
    -------
    pl.Expr
        Expression that subtracts the ``onset`` column from the ``offset`` column.
    """
    onset_time = pl.col('onset')
    offset_time = pl.col('offset')
    return offset_time - onset_time


@register_event_property
def peak_velocity(
*,
velocity_column: str = 'velocity',
n_components: int = 2,
) -> pl.Expr:
"""Peak velocity of an event.
Parameters
----------
position_columns
The column names of the pitch and yaw position components.
velocity_column
The column name of the velocity tuples.
n_components:
Number of velocity components. Usually these are the two components yaw and pitch.
Raises
------
TypeError
If position_columns not of type tuple, position_columns not of length 2, or elements of
position_columns not of type str.
ValueError
If number of components is not 2.
"""
_check_position_columns(position_columns)
_check_has_two_componenents(n_components)

x_position = pl.col(position_columns[0])
y_position = pl.col(position_columns[1])
x_velocity = pl.col(velocity_column).arr.get(0)
y_velocity = pl.col(velocity_column).arr.get(1)

return (
(x_position.head(n=1) - x_position.reverse().head(n=1)).pow(2)
+ (y_position.head(n=1) - y_position.reverse().head(n=1)).pow(2)
).sqrt()
return (x_velocity.pow(2) + y_velocity.pow(2)).sqrt().max()


@register_event_property
def position(
method: str = 'mean',
*,
position_column: str = 'position',
n_components: int = 2,
) -> pl.Expr:
Expand Down Expand Up @@ -186,40 +208,11 @@ def position(

component_expressions.append(expression_component)

return pl.concat_list(component_expressions)


def _check_position_columns(position_columns: tuple[str, str]) -> None:
"""Check if position_columns is of type tuple[str, str]."""
if not isinstance(position_columns, tuple):
raise TypeError(
'position_columns must be of type tuple[str, str]'
f' but is of type {type(position_columns).__name__}',
)
if len(position_columns) != 2:
raise TypeError(
f'position_columns must be of length of 2 but is of length {len(position_columns)}',
)
if not all(isinstance(velocity_column, str) for velocity_column in position_columns):
raise TypeError(
'position_columns must be of type tuple[str, str] but is '
f'tuple[{type(position_columns[0]).__name__}, {type(position_columns[1]).__name__}]',
)
# Not sure why first() is needed here, but an outer list is being created somehow.
return pl.concat_list(component_expressions).first()


def _check_velocity_columns(velocity_columns: tuple[str, str]) -> None:
"""Check if velocity_columns is of type tuple[str, str]."""
if not isinstance(velocity_columns, tuple):
raise TypeError(
'velocity_columns must be of type tuple[str, str]'
f' but is of type {type(velocity_columns).__name__}',
)
if len(velocity_columns) != 2:
raise TypeError(
f'velocity_columns must be of length of 2 but is of length {len(velocity_columns)}',
)
if not all(isinstance(velocity_column, str) for velocity_column in velocity_columns):
raise TypeError(
'velocity_columns must be of type tuple[str, str] but is '
f'tuple[{type(velocity_columns[0]).__name__}, {type(velocity_columns[1]).__name__}]',
)
def _check_has_two_componenents(n_components: int) -> None:
"""Check that number of componenents is two."""
if n_components != 2:
raise ValueError('data must have exactly two components')
42 changes: 41 additions & 1 deletion tests/events/event_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,54 @@ def test_event_gaze_processor_init_exceptions(args, kwargs, exception, msg_subst
),
id='dispersion_single_event_complete_window',
),
pytest.param(
pl.from_dict(
{'subject_id': [1], 'onset': [0], 'offset': [10]},
schema={'subject_id': pl.Int64, 'onset': pl.Int64, 'offset': pl.Int64},
),
pl.from_dict(
{
'subject_id': np.ones(10),
'time': np.arange(10),
'x_vel': np.concatenate([np.arange(0.1, 1.1, 0.1)]),
'y_vel': np.concatenate([np.arange(0.1, 1.1, 0.1)]),
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_vel': pl.Float64,
'y_vel': pl.Float64,
},
),
{'event_properties': 'peak_velocity'},
{'identifiers': 'subject_id'},
pl.from_dict(
{
'subject_id': [1],
'name': [None],
'onset': [0],
'offset': [10],
'peak_velocity': [np.sqrt(2)],
},
schema={
'subject_id': pl.Int64,
'name': pl.Utf8,
'onset': pl.Int64,
'offset': pl.Int64,
'peak_velocity': pl.Float64,
},
),
id='peak_velocity_single_event_complete_window',
),
],
)
def test_event_gaze_processor_process_correct_result(
event_df, gaze_df, init_kwargs, process_kwargs, expected_dataframe,
):
processor = EventGazeProcessor(**init_kwargs)
events = EventDataFrame(event_df)
gaze = GazeDataFrame(gaze_df)

processor = EventGazeProcessor(**init_kwargs)
property_result = processor.process(events, gaze, **process_kwargs)
assert_frame_equal(property_result, expected_dataframe)

Expand Down
Loading

0 comments on commit d9a94ef

Please sign in to comment.