diff --git a/pyrfd/plots.py b/pyrfd/plots.py
index 534dfa8..5fa544f 100644
--- a/pyrfd/plots.py
+++ b/pyrfd/plots.py
@@ -6,6 +6,16 @@
 
 from .regression import ScalarRegression
 
+def selection(sorted_list, num_elts):
+    """
+    return a selection of num_elts from the sorted_list (evenly spaced in the index)
+    always includes the first and last index
+    """
+    if len(sorted_list) < num_elts:
+        return sorted_list
+    idxs = np.round(np.linspace(0, len(sorted_list) - 1, num_elts)).astype(int)
+    return sorted_list[idxs]
+
 
 def plot_loss(ax, df: pd.DataFrame, *, mean, var_reg: ScalarRegression):
     """
diff --git a/pyrfd/regression.py b/pyrfd/regression.py
index 37fc6f5..aafa3e7 100644
--- a/pyrfd/regression.py
+++ b/pyrfd/regression.py
@@ -39,17 +39,6 @@ def intercept(self):
         return self.intercept_
 
 
-def selection(sorted_list, num_elts):
-    """
-    return a selection of num_elts from the sorted_list (evenly spaced in the index)
-    always includes the first and last index
-    """
-    if len(sorted_list) < num_elts:
-        return sorted_list
-    idxs = np.round(np.linspace(0, len(sorted_list) - 1, num_elts)).astype(int)
-    return sorted_list[idxs]
-
-
 def fit_mean_var(
     batch_sizes: np.array,
     batch_losses: np.array,