refactor: identify datetime format manually only when polars failed
polinaeterna committed Jan 15, 2025
1 parent 8afade1 commit 341676c
Showing 1 changed file with 23 additions and 23 deletions.
services/worker/src/worker/statistics_utils.py (23 additions, 23 deletions)
@@ -478,22 +478,11 @@ def is_class(n_unique: int, n_samples: int) -> bool:
         ) or n_unique <= NUM_BINS
 
     @staticmethod
-    def is_datetime(data: pl.DataFrame, column_name: str) -> tuple[bool, Optional[str]]:
-        """Check if first 100 non-null samples in a column match datetime format. If true, also return datetime format"""
+    def is_datetime(data: pl.DataFrame, column_name: str) -> bool:
+        """Check if first 100 non-null samples in a column match datetime format."""
 
         values = data.filter(pl.col(column_name).is_not_null()).head(100)[column_name].to_list()
-        _is_datetime = all(is_datetime(value) for value in values) if len(values) > 0 else False
-
-        if _is_datetime:
-            formats = [identify_datetime_format(value) for value in values]
-            if len(set(formats)) == 1:
-                datetime_format = formats[0]
-                if not datetime_format:
-                    raise ValueError("Values are datetime but format is not identified")
-                return True, datetime_format
-            raise StatisticsComputationError("Multiple datetime formats detected. ")
-
-        return False, None
+        return all(is_datetime(value) for value in values) if len(values) > 0 else False
 
     @classmethod
     def compute_transformed_data(
@@ -515,13 +504,11 @@ def _compute_statistics(
     ) -> Union[CategoricalStatisticsItem, NumericalStatisticsItem, DatetimeStatisticsItem]:
         nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
         n_unique = data[column_name].n_unique()
-        _is_datetime, datetime_format = cls.is_datetime(data, column_name)
-        if _is_datetime:
+        if cls.is_datetime(data, column_name):
             datetime_stats: DatetimeStatisticsItem = DatetimeColumn.compute_statistics(
                 data,
                 column_name=column_name,
                 n_samples=n_samples,
-                format=datetime_format,
             )
             return datetime_stats
 
@@ -772,13 +759,27 @@ def compute_transformed_data(
     def shift_and_convert_to_string(base_date: datetime.datetime, seconds: Union[int, float]) -> str:
         return datetime_to_string(base_date + datetime.timedelta(seconds=seconds))
 
+    @staticmethod
+    def get_format(data: pl.DataFrame, column_name: str) -> str:
+        values = data.filter(pl.col(column_name).is_not_null()).head(100)[column_name].to_list()
+        formats = [identify_datetime_format(value) for value in values]
+        if len(set(formats)) == 1:
+            datetime_format = formats[0]
+            if not datetime_format:
+                raise StatisticsComputationError(
+                    f"Values are datetime but format is not identified. Example: {values[0]}. "
+                )
+        else:
+            raise StatisticsComputationError("Multiple datetime formats detected. ")
+
+        return datetime_format
+
     @classmethod
     def _compute_statistics(
         cls,
         data: pl.DataFrame,
         column_name: str,
         n_samples: int,
-        format: Optional[str] = None,
     ) -> DatetimeStatisticsItem:
         nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
         if nan_count == n_samples:  # all values are None
@@ -799,7 +800,8 @@ def _compute_statistics(
         try:
             data = data.with_columns(pl.col(column_name).str.to_datetime())
         except pl.ComputeError:
-            data = data.with_columns(pl.col(column_name).str.to_datetime(format=format))
+            datetime_format = cls.get_format(data, column_name)
+            data = data.with_columns(pl.col(column_name).str.to_datetime(format=datetime_format))
 
         min_date: datetime.datetime = data[column_name].min()  # type: ignore  # mypy infers type of datetime column .min() incorrectly
         timedelta_column_name = f"{column_name}_timedelta"
@@ -838,10 +840,8 @@ def _compute_statistics(
             ),
         )
 
-    def compute_and_prepare_response(
-        self, data: pl.DataFrame, format: Optional[str] = None
-    ) -> StatisticsPerColumnItem:
-        stats = self.compute_statistics(data, column_name=self.name, n_samples=self.n_samples, format=format)
+    def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem:
+        stats = self.compute_statistics(data, column_name=self.name, n_samples=self.n_samples)
         return StatisticsPerColumnItem(
             column_name=self.name,
             column_type=ColumnType.DATETIME,
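For context on the new flow: `_compute_statistics` now lets polars infer the datetime format on its own and only falls back to the new `get_format` helper when that inference raises `pl.ComputeError`. Below is a minimal, self-contained sketch of this try-first-then-detect pattern; the candidate format list, the `identify_format` helper, and the sample data are illustrative stand-ins (the worker uses its own `identify_datetime_format`), not the actual implementation.

from typing import Optional
import datetime

import polars as pl

# Illustrative candidate formats to try when polars cannot infer one itself
# (a stand-in for the worker's identify_datetime_format helper).
CANDIDATE_FORMATS = ["%Y-%m-%d %H:%M:%S", "%Y-%m-%d", "%d/%m/%Y", "%d.%m.%Y"]


def identify_format(value: str) -> Optional[str]:
    # Return the first candidate format that parses the value, or None.
    for fmt in CANDIDATE_FORMATS:
        try:
            datetime.datetime.strptime(value, fmt)
            return fmt
        except ValueError:
            continue
    return None


def parse_datetime_column(data: pl.DataFrame, column_name: str) -> pl.DataFrame:
    # First let polars infer the datetime format automatically.
    try:
        return data.with_columns(pl.col(column_name).str.to_datetime())
    except pl.ComputeError:
        # Fallback, mirroring get_format(): detect a single explicit format
        # from the first 100 non-null values and retry with it.
        values = data.filter(pl.col(column_name).is_not_null()).head(100)[column_name].to_list()
        formats = {identify_format(value) for value in values}
        if len(formats) != 1 or None in formats:
            raise ValueError(f"Could not identify a single datetime format: {formats}")
        return data.with_columns(pl.col(column_name).str.to_datetime(format=formats.pop()))


# Usage: if polars parses the strings directly, the fallback is never hit;
# if inference fails with ComputeError, the detected format is used instead.
df = pl.DataFrame({"created": ["31.12.2024", "15.01.2025", None]})
print(parse_datetime_column(df, "created"))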
