diff --git a/pyrfd/plots.py b/pyrfd/plots.py index 534dfa8..5fa544f 100644 --- a/pyrfd/plots.py +++ b/pyrfd/plots.py @@ -6,6 +6,16 @@ from .regression import ScalarRegression +def selection(sorted_list, num_elts): + """ + return a selection of num_elts from the sorted_list (evenly spaced in the index) + always includes the first and last index + """ + if len(sorted_list) < num_elts: + return sorted_list + idxs = np.round(np.linspace(0, len(sorted_list) - 1, num_elts)).astype(int) + return sorted_list[idxs] + def plot_loss(ax, df: pd.DataFrame, *, mean, var_reg: ScalarRegression): """ diff --git a/pyrfd/regression.py b/pyrfd/regression.py index 37fc6f5..aafa3e7 100644 --- a/pyrfd/regression.py +++ b/pyrfd/regression.py @@ -39,17 +39,6 @@ def intercept(self): return self.intercept_ -def selection(sorted_list, num_elts): - """ - return a selection of num_elts from the sorted_list (evenly spaced in the index) - always includes the first and last index - """ - if len(sorted_list) < num_elts: - return sorted_list - idxs = np.round(np.linspace(0, len(sorted_list) - 1, num_elts)).astype(int) - return sorted_list[idxs] - - def fit_mean_var( batch_sizes: np.array, batch_losses: np.array,