Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more ruff rules ICN, PIE #9931

Merged
merged 3 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
exclude: test-data/ert/eclipse/parse/ERROR.PRT # exact format is needed for testing

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3
rev: v0.9.4
hooks:
- id: ruff
args: [ --fix ]
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ select = [
"ASYNC", # flake8-async
"RUF", # ruff specific rules
"UP", # pyupgrade
"ICN", # flake8-import-conventions
"PIE", # flake8-pie
]
preview = true
ignore = [
Expand Down
10 changes: 4 additions & 6 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import iterative_ensemble_smoother as ies
import numpy as np
import polars
import polars as pl
import psutil
import scipy
from iterative_ensemble_smoother.experimental import AdaptiveESMDA
Expand Down Expand Up @@ -222,14 +222,12 @@ def _load_observations_and_responses(
scaling[obs_group_mask] *= scaling_factors

scaling_factors_dfs.append(
polars.DataFrame(
pl.DataFrame(
{
"input_group": [", ".join(input_group)] * len(scaling_factors),
"index": indexes[obs_group_mask],
"obs_key": obs_keys[obs_group_mask],
"scaling_factor": polars.Series(
scaling_factors, dtype=polars.Float32
),
"scaling_factor": pl.Series(scaling_factors, dtype=pl.Float32),
}
)
)
Expand Down Expand Up @@ -259,7 +257,7 @@ def _load_observations_and_responses(
)

if scaling_factors_dfs:
scaling_factors_df = polars.concat(scaling_factors_dfs)
scaling_factors_df = pl.concat(scaling_factors_dfs)
ensemble.save_observation_scaling_factors(scaling_factors_df)

# Recompute with updated scales
Expand Down
1 change: 0 additions & 1 deletion src/ert/analysis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

class AnalysisEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
pass


class AnalysisStatusEvent(AnalysisEvent):
Expand Down
4 changes: 2 additions & 2 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
overload,
)

import polars
import polars as pl
from pydantic import ValidationError as PydanticValidationError
from pydantic import field_validator
from pydantic.dataclasses import dataclass, rebuild_dataclass
Expand Down Expand Up @@ -298,7 +298,7 @@ def __post_init__(self) -> None:
if self.user_config_file
else os.getcwd()
)
self.observations: dict[str, polars.DataFrame] = self.enkf_obs.datasets
self.observations: dict[str, pl.DataFrame] = self.enkf_obs.datasets

@staticmethod
def with_plugins(
Expand Down
22 changes: 11 additions & 11 deletions src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Self

import numpy as np
import polars
import polars as pl

from ert.substitutions import substitute_runpath_name
from ert.validation import rangestring_to_list
Expand Down Expand Up @@ -116,8 +116,8 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Self | None:
report_steps_list=report_steps,
)

def read_from_file(self, run_path: str, iens: int, iter: int) -> polars.DataFrame:
def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
def read_from_file(self, run_path: str, iens: int, iter: int) -> pl.DataFrame:
def _read_file(filename: Path, report_step: int) -> pl.DataFrame:
try:
data = np.loadtxt(filename, ndmin=1)
except ValueError as err:
Expand All @@ -129,13 +129,13 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
except ValueError as err:
raise InvalidResponseFile(str(err)) from err
data[active_list == 0] = np.nan
return polars.DataFrame(
return pl.DataFrame(
{
"report_step": polars.Series(
np.full(len(data), report_step), dtype=polars.UInt16
"report_step": pl.Series(
np.full(len(data), report_step), dtype=pl.UInt16
),
"index": polars.Series(np.arange(len(data)), dtype=polars.UInt16),
"values": polars.Series(data, dtype=polars.Float32),
"index": pl.Series(np.arange(len(data)), dtype=pl.UInt16),
"values": pl.Series(data, dtype=pl.Float32),
}
)

Expand Down Expand Up @@ -167,9 +167,9 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
errors.append(err)

if len(datasets_per_report_step) > 0:
ds_all_report_steps = polars.concat(datasets_per_report_step)
ds_all_report_steps = pl.concat(datasets_per_report_step)
ds_all_report_steps.insert_column(
0, polars.Series("response_key", [name] * len(ds_all_report_steps))
0, pl.Series("response_key", [name] * len(ds_all_report_steps))
)
datasets_per_name.append(ds_all_report_steps)

Expand All @@ -185,7 +185,7 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
f"{self.name}, errors: {','.join([str(err) for err in errors])}"
)

combined = polars.concat(datasets_per_name)
combined = pl.concat(datasets_per_name)
return combined

def get_args_for_key(self, key: str) -> tuple[str | None, list[int] | None]:
Expand Down
28 changes: 13 additions & 15 deletions src/ert/config/observation_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if TYPE_CHECKING:
from datetime import datetime

import polars
import polars as pl


@dataclass
Expand All @@ -30,7 +30,7 @@ def __iter__(self) -> Iterable[SummaryObservation | GenObservation]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset(self, active_list: list[int]) -> polars.DataFrame:
def to_dataset(self, active_list: list[int]) -> pl.DataFrame:
if self.observation_type == ObservationType.GENERAL:
dataframes = []
for time_step, node in self.observations.items():
Expand All @@ -39,24 +39,22 @@ def to_dataset(self, active_list: list[int]) -> polars.DataFrame:

assert isinstance(node, GenObservation)
dataframes.append(
polars.DataFrame(
pl.DataFrame(
{
"response_key": self.data_key,
"observation_key": self.observation_key,
"report_step": polars.Series(
"report_step": pl.Series(
np.full(len(node.indices), time_step),
dtype=polars.UInt16,
dtype=pl.UInt16,
),
"index": polars.Series(node.indices, dtype=polars.UInt16),
"observations": polars.Series(
node.values, dtype=polars.Float32
),
"std": polars.Series(node.stds, dtype=polars.Float32),
"index": pl.Series(node.indices, dtype=pl.UInt16),
"observations": pl.Series(node.values, dtype=pl.Float32),
"std": pl.Series(node.stds, dtype=pl.Float32),
}
)
)

combined = polars.concat(dataframes)
combined = pl.concat(dataframes)
return combined
elif self.observation_type == ObservationType.SUMMARY:
observations = []
Expand All @@ -74,15 +72,15 @@ def to_dataset(self, active_list: list[int]) -> polars.DataFrame:
observations.append(n.value)
errors.append(n.std)

dates_series = polars.Series(dates).dt.cast_time_unit("ms")
dates_series = pl.Series(dates).dt.cast_time_unit("ms")

return polars.DataFrame(
return pl.DataFrame(
{
"response_key": actual_response_key,
"observation_key": actual_observation_keys,
"time": dates_series,
"observations": polars.Series(observations, dtype=polars.Float32),
"std": polars.Series(errors, dtype=polars.Float32),
"observations": pl.Series(observations, dtype=pl.Float32),
"std": pl.Series(errors, dtype=pl.Float32),
}
)
else:
Expand Down
8 changes: 4 additions & 4 deletions src/ert/config/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

import numpy as np
import polars
import polars as pl

from ert.validation import rangestring_to_list

Expand Down Expand Up @@ -43,7 +43,7 @@ class EnkfObs:
obs_time: list[datetime] = field(default_factory=list)

def __post_init__(self) -> None:
grouped: dict[str, list[polars.DataFrame]] = {}
grouped: dict[str, list[pl.DataFrame]] = {}
for vec in self.obs_vectors.values():
if vec.observation_type == ObservationType.SUMMARY:
if "summary" not in grouped:
Expand All @@ -57,12 +57,12 @@ def __post_init__(self) -> None:

grouped["gen_data"].append(vec.to_dataset([]))

datasets: dict[str, polars.DataFrame] = {}
datasets: dict[str, pl.DataFrame] = {}

for name, dfs in grouped.items():
non_empty_dfs = [df for df in dfs if not df.is_empty()]
if len(non_empty_dfs) > 0:
datasets[name] = polars.concat(non_empty_dfs).sort("observation_key")
datasets[name] = pl.concat(non_empty_dfs).sort("observation_key")

self.datasets = datasets

Expand Down
4 changes: 2 additions & 2 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any, Self

import polars
import polars as pl

from ert.config.parameter_config import CustomDict
from ert.config.parsing import ConfigDict
Expand All @@ -23,7 +23,7 @@ class ResponseConfig(ABC):
has_finalized_keys: bool = False

@abstractmethod
def read_from_file(self, run_path: str, iens: int, iter: int) -> polars.DataFrame:
def read_from_file(self, run_path: str, iens: int, iter: int) -> pl.DataFrame:
"""Reads the data for the response from run_path.

Raises:
Expand Down
10 changes: 5 additions & 5 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .responses_index import responses_index

logger = logging.getLogger(__name__)
import polars
import polars as pl


@dataclass
Expand All @@ -36,7 +36,7 @@ def expected_input_files(self) -> list[str]:
base = self.input_files[0]
return [f"{base}.UNSMRY", f"{base}.SMSPEC"]

def read_from_file(self, run_path: str, iens: int, iter: int) -> polars.DataFrame:
def read_from_file(self, run_path: str, iens: int, iter: int) -> pl.DataFrame:
filename = substitute_runpath_name(self.input_files[0], iens, iter)
_, keys, time_map, data = read_summary(f"{run_path}/{filename}", self.keys)
if len(data) == 0 or len(keys) == 0:
Expand All @@ -49,12 +49,12 @@ def read_from_file(self, run_path: str, iens: int, iter: int) -> polars.DataFram

# Important: Pick lowest unit resolution to allow for using
# datetimes many years into the future
time_map_series = polars.Series(time_map).dt.cast_time_unit("ms")
df = polars.DataFrame(
time_map_series = pl.Series(time_map).dt.cast_time_unit("ms")
df = pl.DataFrame(
{
"response_key": keys,
"time": [time_map_series for _ in data],
"values": [polars.Series(row, dtype=polars.Float32) for row in data],
"values": [pl.Series(row, dtype=pl.Float32) for row in data],
}
)
df = df.explode("values", "time")
Expand Down
16 changes: 7 additions & 9 deletions src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import pandas as pd
import polars
import polars as pl
import xarray as xr
from polars.exceptions import ColumnNotFoundError

Expand Down Expand Up @@ -165,7 +165,7 @@ def data_for_key(
return pd.DataFrame()

try:
vals = data.filter(polars.col("report_step").eq(report_step))
vals = data.filter(pl.col("report_step").eq(report_step))
pivoted = vals.drop("response_key", "report_step").pivot(
on="index", values="values"
)
Expand Down Expand Up @@ -227,7 +227,7 @@ def _get_observations(

for response_type, df in experiment.observations.items():
if observation_keys is not None:
df = df.filter(polars.col("observation_key").is_in(observation_keys))
df = df.filter(pl.col("observation_key").is_in(observation_keys))

if df.is_empty():
continue
Expand All @@ -240,9 +240,7 @@ def _get_observations(
"observations": "values",
}
)
df = df.with_columns(
polars.Series(name="x_axis", values=df.map_rows(x_axis_fn))
)
df = df.with_columns(pl.Series(name="x_axis", values=df.map_rows(x_axis_fn)))
df = df.sort("x_axis")

for obs_key, _obs_df in df.group_by("name"):
Expand Down Expand Up @@ -283,8 +281,8 @@ def get_observation_keys_for_response(
if "gen_data" in ensemble.experiment.observations:
observations = ensemble.experiment.observations["gen_data"]
filtered = observations.filter(
polars.col("response_key").eq(response_key)
& polars.col("report_step").eq(report_step)
pl.col("response_key").eq(response_key)
& pl.col("report_step").eq(report_step)
)

if filtered.is_empty():
Expand All @@ -302,7 +300,7 @@ def get_observation_keys_for_response(

if "summary" in ensemble.experiment.observations:
observations = ensemble.experiment.observations["summary"]
filtered = observations.filter(polars.col("response_key").eq(response_key))
filtered = observations.filter(pl.col("response_key").eq(response_key))

if filtered.is_empty():
return []
Expand Down
8 changes: 4 additions & 4 deletions src/ert/data/_measured_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np
import pandas as pd
import polars
import polars as pl

if TYPE_CHECKING:
from ert.storage import Ensemble
Expand Down Expand Up @@ -106,7 +106,7 @@ def _get_data(
response_cls,
) in ensemble.experiment.response_configuration.items():
observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(observation_keys)
pl.col("observation_key").is_in(observation_keys)
)
responses_for_type = ensemble.load_responses(
response_type,
Expand Down Expand Up @@ -142,7 +142,7 @@ def _get_data(
)

joined = joined.sort(by="observation_key").with_columns(
polars.concat_str(response_cls.primary_key, separator=", ").alias(
pl.concat_str(response_cls.primary_key, separator=", ").alias(
"key_index"
)
)
Expand All @@ -154,7 +154,7 @@ def _get_data(
if not joined.is_empty():
dfs.append(joined)

df = polars.concat(dfs)
df = pl.concat(dfs)
df = df.rename(
{
"observations": "OBS",
Expand Down
Loading
Loading