Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 23, 2025
1 parent a507e10 commit e00745b
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/optimagic/optimization/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ def __getitem__(self, key: str) -> Any:


def _get_flat_params(params: list[PyTree]) -> list[list[float]]:
if len(params) > 0 and _is_1d_array(params[0]):
# fast path
fast_path = len(params) > 0 and _is_1d_array(params[0])
if fast_path:
flatten = lambda x: x.tolist()
else:
registry = get_registry(extended=True)
Expand All @@ -327,9 +327,11 @@ def _get_flat_params(params: list[PyTree]) -> list[list[float]]:


def _get_flat_param_names(param: PyTree) -> list[str]:
if _is_1d_array(param):
# fast path
return np.arange(param.size).astype(str).tolist()
fast_path = _is_1d_array(param)
if fast_path:
# Mypy raises an error here because .tolist() returns a str for zero-dimensional
# arrays, but the fast path is only taken for 1d arrays, so it can be ignored.
return np.arange(param.size).astype(str).tolist() # type: ignore[return-value]

registry = get_registry(extended=True)
return leaf_names(param, registry=registry)
Expand Down Expand Up @@ -409,7 +411,7 @@ def _apply_to_batch(
batch_starts = _get_batch_start(batch_ids)
batch_stops = [*batch_starts[1:], len(data)]

batch_results = []
batch_results: list[float] = []
for start, stop in zip(batch_starts, batch_stops, strict=True):
batch_data = data[start:stop]
batch_id = batch_ids[start]
Expand All @@ -432,7 +434,7 @@ def _apply_to_batch(
)
raise ValueError(msg)

batch_results.append(reduced)
batch_results.append(reduced) # type: ignore[arg-type]

out = np.zeros_like(data)
out[batch_starts] = batch_results
Expand All @@ -447,4 +449,5 @@ def _get_batch_start(batch_ids: list[int]) -> list[int]:
"""
ids_arr = np.array(batch_ids, dtype=np.int64)
indices = np.where(ids_arr[:-1] != ids_arr[1:])[0] + 1
return np.insert(indices, 0, 0).tolist()
list_indices: list[int] = indices.tolist() # type: ignore[assignment]
return [0, *list_indices]

0 comments on commit e00745b

Please sign in to comment.