diff --git a/prompterator/constants.py b/prompterator/constants.py index 81d0de6..04d8447 100644 --- a/prompterator/constants.py +++ b/prompterator/constants.py @@ -82,6 +82,7 @@ def call(self, input, **kwargs): RESPONSE_DATA_COL: object, LABEL_COL: bool, } +# these are the columns that users won't be able to show or inject into their prompts COLS_NOT_FOR_PROMPT_INTERPOLATION = [ TEXT_GENERATED_COL, SYSTEM_PROMPT_TEMPLATE_COL, @@ -91,7 +92,7 @@ def call(self, input, **kwargs): ] LABEL_GOOD = "good" LABEL_BAD = "bad" -COLS_TO_SHOW = [TEXT_ORIG_COL, TEXT_GENERATED_COL, LABEL_COL] +DUMMY_DATA_COLS = [TEXT_ORIG_COL, TEXT_GENERATED_COL, LABEL_COL] LABEL_VALUE_COLOURS = { LABEL_GOOD: "#56E7AB", LABEL_BAD: "#FE8080", diff --git a/prompterator/main.py b/prompterator/main.py index f36a51a..35b5ab1 100644 --- a/prompterator/main.py +++ b/prompterator/main.py @@ -514,12 +514,12 @@ def set_up_ui_labelling(): def show_col_selection(): if st.session_state.get("df") is not None: - columns_sel = st.session_state.df.columns.tolist() - columns_sel.remove(c.TEXT_GENERATED_COL) - columns_sel.remove(c.LABEL_COL) - columns_sel.remove(c.RESPONSE_DATA_COL) + available_columns = st.session_state.df.columns.tolist() + available_columns = [ + col for col in available_columns if col not in c.COLS_NOT_FOR_PROMPT_INTERPOLATION + ] st.session_state["columns_to_show"] = st.multiselect( - "Columns to show", columns_sel, [c.TEXT_ORIG_COL] + "Columns to show", options=available_columns, default=[c.TEXT_ORIG_COL] ) diff --git a/prompterator/utils.py b/prompterator/utils.py index 64c0afd..5113f87 100644 --- a/prompterator/utils.py +++ b/prompterator/utils.py @@ -282,6 +282,6 @@ def get_dummy_dataframe(): ) # to fail if the list of required columns changes; we'll then update the hard-coded dict above - assert set(df.columns) == set(c.COLS_TO_SHOW) + assert set(df.columns) == set(c.DUMMY_DATA_COLS) return df