9
9
10
10
import numpy as np
11
11
import pandas as pd
12
+ from numpy .typing import NDArray
12
13
13
14
14
15
def get_cols (df , cols ):
@@ -124,7 +125,7 @@ def sizes_to_indices(sizes):
124
125
125
126
Returns
126
127
-------
127
- list[np.ndarray ]
128
+ list[NDArray ]
128
129
list the indices.
129
130
130
131
"""
@@ -325,29 +326,29 @@ def avg_integral(mat, spline=None, use_spline_intercept=False):
325
326
# random knots
326
327
def sample_knots (
327
328
num_knots : int ,
328
- knot_bounds : np . ndarray ,
329
- min_dist : float | np . ndarray ,
329
+ knot_bounds : NDArray ,
330
+ min_dist : float | NDArray ,
330
331
num_samples : int = 1 ,
331
- ) -> np . ndarray :
332
+ ) -> NDArray :
332
333
"""Sample knot vectors given a set of rules.
333
334
334
335
Parameters
335
336
----------
336
337
num_knots : int
337
338
Number of interior knots.
338
- knot_bounds : np.ndarray , shape(2,) or shape(`num_knots`,2)
339
+ knot_bounds : NDArray , shape(2,) or shape(`num_knots`,2)
339
340
Lower and upper bounds for knots. If shape(2,), boundary knots
340
341
placed at `knot_bounds[0]` and `knot_bounds[1]`. If
341
342
shape(`num_knots`,2), boundary knots placed at
342
343
`knot_bounds[0, 0]` and `knot_bounds[-1, 1]`.
343
- min_dist : float or np.ndarray , shape(`num_knots`+1,)
344
+ min_dist : float or NDArray , shape(`num_knots`+1,)
344
345
Minimum distances between knots.
345
346
num_samples : int, optional
346
347
Number of knot vectors to sample. Default is 1.
347
348
348
349
Returns
349
350
-------
350
- np.ndarray , shape(`num_samples`,`num_knots`+2)
351
+ NDArray , shape(`num_samples`,`num_knots`+2)
351
352
Sampled knot vectors.
352
353
353
354
"""
@@ -380,7 +381,7 @@ def _check_nums(num_name: str, num_val: int) -> None:
380
381
raise ValueError (f"{ num_name } must be at least 1" )
381
382
382
383
383
- def _check_knot_bounds (num_knots : int , knot_bounds : np . ndarray ) -> np . ndarray :
384
+ def _check_knot_bounds (num_knots : int , knot_bounds : NDArray ) -> NDArray :
384
385
"""Check knot_bounds."""
385
386
try :
386
387
knot_bounds = np .asarray (knot_bounds , dtype = float )
@@ -399,7 +400,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
399
400
return knot_bounds
400
401
401
402
402
- def _check_min_dist (num_knots : int , min_dist : float | np . ndarray ) -> np . ndarray :
403
+ def _check_min_dist (num_knots : int , min_dist : float | NDArray ) -> NDArray :
403
404
"""Check knot min_dist."""
404
405
if np .isscalar (min_dist ):
405
406
min_dist = np .tile (min_dist , num_knots + 1 )
@@ -415,8 +416,8 @@ def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
415
416
416
417
417
418
def _check_feasibility (
418
- num_knots : int , knot_bounds : np . ndarray , min_dist : np . ndarray
419
- ) -> tuple [np . ndarray , np . ndarray ]:
419
+ num_knots : int , knot_bounds : NDArray , min_dist : NDArray
420
+ ) -> tuple [NDArray , NDArray ]:
420
421
"""Check knot feasibility and get left and right boundaries."""
421
422
if np .sum (min_dist ) > knot_bounds [- 1 , 1 ] - knot_bounds [0 , 0 ]:
422
423
raise ValueError ("min_dist cannot exceed knot_bounds" )
@@ -561,7 +562,7 @@ def to_list(obj: Any) -> list[Any]:
561
562
return [obj ]
562
563
563
564
564
- def is_numeric_array (array : np . ndarray ) -> bool :
565
+ def is_numeric_array (array : NDArray ) -> bool :
565
566
"""Check if an array is numeric.
566
567
567
568
Parameters
@@ -590,8 +591,8 @@ def is_numeric_array(array: np.ndarray) -> bool:
590
591
591
592
592
593
def expand_array (
593
- array : np . ndarray , shape : tuple [int ], value : Any , name : str
594
- ) -> np . ndarray :
594
+ array : NDArray , shape : tuple [int ], value : Any , name : str
595
+ ) -> NDArray :
595
596
"""Expand array when it is empty.
596
597
597
598
Parameters
@@ -608,7 +609,7 @@ def expand_array(
608
609
609
610
Returns
610
611
-------
611
- np.ndarray
612
+ NDArray
612
613
Expanded array.
613
614
614
615
"""
@@ -630,7 +631,7 @@ def expand_array(
630
631
def ravel_dict (x : dict ) -> dict :
631
632
"""Ravel dictionary."""
632
633
assert all ([isinstance (k , str ) for k in x .keys ()])
633
- assert all ([isinstance (v , np . ndarray ) for v in x .values ()])
634
+ assert all ([isinstance (v , NDArray ) for v in x .values ()])
634
635
new_x = {}
635
636
for k , v in x .items ():
636
637
if v .size == 1 :
0 commit comments