Skip to content

Commit

Permalink
fix: Specify column dtypes in dataset definitions (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored Oct 10, 2023
1 parent 78b649c commit 8836275
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 17 deletions.
12 changes: 12 additions & 0 deletions src/pymovements/datasets/gazebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from dataclasses import field
from typing import Any

import polars as pl

from pymovements.dataset.dataset_definition import DatasetDefinition
from pymovements.dataset.dataset_library import register_dataset
from pymovements.gaze.experiment import Experiment
Expand Down Expand Up @@ -157,5 +159,15 @@ class GazeBase(DatasetDefinition):
custom_read_kwargs: dict[str, Any] = field(
default_factory=lambda: {
'null_values': 'NaN',
'dtypes': {
'n': pl.Int64,
'x': pl.Float32,
'y': pl.Float32,
'val': pl.Int32,
'dP': pl.Float32,
'lab': pl.Int32,
'xT': pl.Float32,
'yT': pl.Float32,
},
},
)
25 changes: 24 additions & 1 deletion src/pymovements/datasets/gazebasevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from dataclasses import field
from typing import Any

import polars as pl

from pymovements.dataset.dataset_definition import DatasetDefinition
from pymovements.dataset.dataset_library import register_dataset
from pymovements.gaze.experiment import Experiment
Expand Down Expand Up @@ -154,4 +156,25 @@ class GazeBaseVR(DatasetDefinition):
},
)

- custom_read_kwargs: dict[str, Any] = field(default_factory=dict)
+ custom_read_kwargs: dict[str, Any] = field(
+ default_factory=lambda: {
+ 'dtypes': {
+ 'n': pl.Float32,
+ 'x': pl.Float32,
+ 'y': pl.Float32,
+ 'lx': pl.Float32,
+ 'ly': pl.Float32,
+ 'rx': pl.Float32,
+ 'ry': pl.Float32,
+ 'xT': pl.Float32,
+ 'yT': pl.Float32,
+ 'zT': pl.Float32,
+ 'clx': pl.Float32,
+ 'cly': pl.Float32,
+ 'clz': pl.Float32,
+ 'crx': pl.Float32,
+ 'cry': pl.Float32,
+ 'crz': pl.Float32,
+ },
+ },
+ )
12 changes: 6 additions & 6 deletions src/pymovements/datasets/hbn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ class HBN(DatasetDefinition):
custom_read_kwargs: dict[str, Any] = field(
default_factory=lambda: {
'separator': ',',
- 'columns': [
- 'time', 'x_pix', 'y_pix',
- ],
- 'dtypes': [
- pl.Float64, pl.Float64, pl.Float64,
- ],
+ 'columns': ['time', 'x_pix', 'y_pix'],
+ 'dtypes': {
+ 'time': pl.Int64,
+ 'x_pix': pl.Float32,
+ 'y_pix': pl.Float32,
+ },
},
)
14 changes: 13 additions & 1 deletion src/pymovements/datasets/judo1000.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

from dataclasses import dataclass
from dataclasses import field
from typing import Any

import polars as pl

from pymovements.dataset.dataset_definition import DatasetDefinition
from pymovements.dataset.dataset_library import register_dataset
Expand Down Expand Up @@ -144,8 +147,17 @@ class JuDo1000(DatasetDefinition):
},
)

- custom_read_kwargs: dict[str, str] = field(
+ custom_read_kwargs: dict[str, Any] = field(
default_factory=lambda: {
+ 'dtypes': {
+ 'trialId': pl.Int32,
+ 'pointId': pl.Int32,
+ 'time': pl.Int64,
+ 'x_left': pl.Float32,
+ 'y_left': pl.Float32,
+ 'x_right': pl.Float32,
+ 'y_right': pl.Float32,
+ },
'separator': '\t',
},
)
17 changes: 9 additions & 8 deletions src/pymovements/datasets/sb_sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,14 @@ class SBSAT(DatasetDefinition):
custom_read_kwargs: dict[str, Any] = field(
default_factory=lambda: {
'separator': '\t',
- 'columns': [
- 'time', 'book_name', 'screen_id',
- 'x_left', 'y_left', 'pupil_left',
- ],
- 'dtypes': [
- pl.Int64, pl.Utf8, pl.Int64,
- pl.Float64, pl.Float64, pl.Float64,
- ],
+ 'columns': ['time', 'book_name', 'screen_id', 'x_left', 'y_left', 'pupil_left'],
+ 'dtypes': {
+ 'time': pl.Int64,
+ 'book_name': pl.Utf8,
+ 'screen_id': pl.Int32,
+ 'x_left': pl.Float32,
+ 'y_left': pl.Float32,
+ 'pupil_left': pl.Float32,
+ },
},
)
13 changes: 12 additions & 1 deletion src/pymovements/datasets/toy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

from dataclasses import dataclass
from dataclasses import field
from typing import Any

import polars as pl

from pymovements.dataset.dataset_definition import DatasetDefinition
from pymovements.dataset.dataset_library import register_dataset
Expand Down Expand Up @@ -132,8 +135,16 @@ class ToyDataset(DatasetDefinition):

column_map: dict[str, str] = field(default_factory=lambda: {})

- custom_read_kwargs: dict[str, str] = field(
+ custom_read_kwargs: dict[str, Any] = field(
default_factory=lambda: {
+ 'columns': ['timestamp', 'x', 'y', 'stimuli_x', 'stimuli_y'],
+ 'dtypes': {
+ 'timestamp': pl.Float32,
+ 'x': pl.Float32,
+ 'y': pl.Float32,
+ 'stimuli_x': pl.Float32,
+ 'stimuli_y': pl.Float32,
+ },
'separator': '\t',
'null_values': '-32768.00',
},
Expand Down

0 comments on commit 8836275

Please sign in to comment.