diff --git a/optuna/storages/_rdb/models.py b/optuna/storages/_rdb/models.py index c9320b4f88..ef435e9ba7 100644 --- a/optuna/storages/_rdb/models.py +++ b/optuna/storages/_rdb/models.py @@ -1,9 +1,8 @@ +from __future__ import annotations + import enum import math from typing import Any -from typing import List -from typing import Optional -from typing import Tuple from sqlalchemy import asc from sqlalchemy import case @@ -77,7 +76,7 @@ def find_or_raise_by_id( return study @classmethod - def find_by_name(cls, study_name: str, session: orm.Session) -> Optional["StudyModel"]: + def find_by_name(cls, study_name: str, session: orm.Session) -> "StudyModel" | None: study = session.query(cls).filter(cls.study_name == study_name).one_or_none() return study @@ -104,7 +103,7 @@ class StudyDirectionModel(BaseModel): ) @classmethod - def where_study_id(cls, study_id: int, session: orm.Session) -> List["StudyDirectionModel"]: + def where_study_id(cls, study_id: int, session: orm.Session) -> list["StudyDirectionModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -123,7 +122,7 @@ class StudyUserAttributeModel(BaseModel): @classmethod def find_by_study_and_key( cls, study: StudyModel, key: str, session: orm.Session - ) -> Optional["StudyUserAttributeModel"]: + ) -> "StudyUserAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.study_id == study.study_id) @@ -136,7 +135,7 @@ def find_by_study_and_key( @classmethod def where_study_id( cls, study_id: int, session: orm.Session - ) -> List["StudyUserAttributeModel"]: + ) -> list["StudyUserAttributeModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -155,7 +154,7 @@ class StudySystemAttributeModel(BaseModel): @classmethod def find_by_study_and_key( cls, study: StudyModel, key: str, session: orm.Session - ) -> Optional["StudySystemAttributeModel"]: + ) -> "StudySystemAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.study_id == study.study_id) @@ -168,7 +167,7 @@ def find_by_study_and_key( @classmethod def where_study_id( cls, study_id: int, session: orm.Session - ) -> List["StudySystemAttributeModel"]: + ) -> list["StudySystemAttributeModel"]: return session.query(cls).filter(cls.study_id == study_id).all() @@ -259,10 +258,7 @@ def find_or_raise_by_id( @classmethod def count( - cls, - session: orm.Session, - study: Optional[StudyModel] = None, - state: Optional[TrialState] = None, + cls, session: orm.Session, study: StudyModel | None = None, state: TrialState | None = None ) -> int: trial_count = session.query(func.count(cls.trial_id)) if study is not None: @@ -294,7 +290,7 @@ class TrialUserAttributeModel(BaseModel): @classmethod def find_by_trial_and_key( cls, trial: TrialModel, key: str, session: orm.Session - ) -> Optional["TrialUserAttributeModel"]: + ) -> "TrialUserAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -307,7 +303,7 @@ def find_by_trial_and_key( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialUserAttributeModel"]: + ) -> list["TrialUserAttributeModel"]: return session.query(cls).filter(cls.trial_id == trial_id).all() @@ -326,7 +322,7 @@ class TrialSystemAttributeModel(BaseModel): @classmethod def find_by_trial_and_key( cls, trial: TrialModel, key: str, session: orm.Session - ) -> Optional["TrialSystemAttributeModel"]: + ) -> "TrialSystemAttributeModel" | None: attribute = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -339,7 +335,7 @@ def find_by_trial_and_key( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialSystemAttributeModel"]: + ) -> list["TrialSystemAttributeModel"]: return session.query(cls).filter(cls.trial_id == trial_id).all() @@ -381,7 +377,7 @@ def _check_compatibility_with_previous_trial_param_distributions( @classmethod def find_by_trial_and_param_name( cls, trial: TrialModel, param_name: str, session: orm.Session - ) -> Optional["TrialParamModel"]: + ) -> "TrialParamModel" | None: param_distribution = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -403,7 +399,7 @@ def find_or_raise_by_trial_and_param_name( return param_distribution @classmethod - def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialParamModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialParamModel"]: trial_params = session.query(cls).filter(cls.trial_id == trial_id).all() return trial_params @@ -428,10 +424,7 @@ class TrialValueType(enum.Enum): ) @classmethod - def value_to_stored_repr( - cls, - value: float, - ) -> Tuple[Optional[float], TrialValueType]: + def value_to_stored_repr(cls, value: float) -> tuple[float | None, TrialValueType]: if value == float("inf"): return (None, cls.TrialValueType.INF_POS) elif value == float("-inf"): @@ -440,7 +433,7 @@ def value_to_stored_repr( return (value, cls.TrialValueType.FINITE) @classmethod - def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType) -> float: + def stored_repr_to_value(cls, value: float | None, float_type: TrialValueType) -> float: if float_type == cls.TrialValueType.INF_POS: assert value is None return float("inf") @@ -455,7 +448,7 @@ def stored_repr_to_value(cls, value: Optional[float], float_type: TrialValueType @classmethod def find_by_trial_and_objective( cls, trial: TrialModel, objective: int, session: orm.Session - ) -> Optional["TrialValueModel"]: + ) -> "TrialValueModel" | None: trial_value = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -466,7 +459,7 @@ def find_by_trial_and_objective( return trial_value @classmethod - def where_trial_id(cls, trial_id: int, session: orm.Session) -> List["TrialValueModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> list["TrialValueModel"]: trial_values = ( session.query(cls).filter(cls.trial_id == trial_id).order_by(asc(cls.objective)).all() ) @@ -495,9 +488,8 @@ class TrialIntermediateValueType(enum.Enum): @classmethod def intermediate_value_to_stored_repr( - cls, - value: float, - ) -> Tuple[Optional[float], TrialIntermediateValueType]: + cls, value: float + ) -> tuple[float | None, TrialIntermediateValueType]: if math.isnan(value): return (None, cls.TrialIntermediateValueType.NAN) elif value == float("inf"): @@ -509,7 +501,7 @@ def intermediate_value_to_stored_repr( @classmethod def stored_repr_to_intermediate_value( - cls, value: Optional[float], float_type: TrialIntermediateValueType + cls, value: float | None, float_type: TrialIntermediateValueType ) -> float: if float_type == cls.TrialIntermediateValueType.NAN: assert value is None @@ -528,7 +520,7 @@ def stored_repr_to_intermediate_value( @classmethod def find_by_trial_and_step( cls, trial: TrialModel, step: int, session: orm.Session - ) -> Optional["TrialIntermediateValueModel"]: + ) -> "TrialIntermediateValueModel" | None: trial_intermediate_value = ( session.query(cls) .filter(cls.trial_id == trial.trial_id) @@ -541,7 +533,7 @@ def find_by_trial_and_step( @classmethod def where_trial_id( cls, trial_id: int, session: orm.Session - ) -> List["TrialIntermediateValueModel"]: + ) -> list["TrialIntermediateValueModel"]: trial_intermediate_values = session.query(cls).filter(cls.trial_id == trial_id).all() return trial_intermediate_values @@ -559,9 +551,7 @@ class TrialHeartbeatModel(BaseModel): ) @classmethod - def where_trial_id( - cls, trial_id: int, session: orm.Session - ) -> Optional["TrialHeartbeatModel"]: + def where_trial_id(cls, trial_id: int, session: orm.Session) -> "TrialHeartbeatModel" | None: return session.query(cls).filter(cls.trial_id == trial_id).one_or_none() @@ -574,6 +564,6 @@ class VersionInfoModel(BaseModel): library_version = _Column(String(MAX_VERSION_LENGTH)) @classmethod - def find(cls, session: orm.Session) -> Optional["VersionInfoModel"]: + def find(cls, session: orm.Session) -> "VersionInfoModel" | None: version_info = session.query(cls).one_or_none() return version_info