diff --git a/examples/expanders.py b/examples/expanders.py new file mode 100644 index 0000000..939168c --- /dev/null +++ b/examples/expanders.py @@ -0,0 +1,30 @@ +""" +An example of using the Expander annotation with nested objects. +""" + +from typing_extensions import Annotated, List +from pydantic import BaseModel +import streamlit as st +import streamlit_pydantic as sp +from streamlit_pydantic import Expander + + +class Child(BaseModel): + """Child class.""" + + name: str + age: int + + +class Parent(BaseModel): + """Parent class.""" + + occupation: str + child: Annotated[List[Child], Expander] + + +data = sp.pydantic_input("form", model=Parent) + +if data: + obj = Parent.model_validate(data) + st.json(obj.model_dump()) diff --git a/src/streamlit_pydantic/__init__.py b/src/streamlit_pydantic/__init__.py index fd80597..3fb0420 100644 --- a/src/streamlit_pydantic/__init__.py +++ b/src/streamlit_pydantic/__init__.py @@ -16,3 +16,5 @@ from .ui_renderer import pydantic_output as _pydantic_output pydantic_output = st._gather_metrics("pydantic_output", _pydantic_output) + +from .annotations import Expander diff --git a/src/streamlit_pydantic/annotations.py b/src/streamlit_pydantic/annotations.py new file mode 100644 index 0000000..6191306 --- /dev/null +++ b/src/streamlit_pydantic/annotations.py @@ -0,0 +1,8 @@ +""" +Annotations to be used within Pydantic objects. +""" + +from typing import Annotated, TypeVar + + +Expander = TypeVar("Expander") diff --git a/src/streamlit_pydantic/ui_renderer.py b/src/streamlit_pydantic/ui_renderer.py index 989a7e6..1419516 100644 --- a/src/streamlit_pydantic/ui_renderer.py +++ b/src/streamlit_pydantic/ui_renderer.py @@ -17,6 +17,9 @@ from streamlit_pydantic import schema_utils +from .annotations import Expander + + _OVERWRITE_STREAMLIT_KWARGS_PREFIX = "st_kwargs_" @@ -168,6 +171,10 @@ def render_ui(self) -> Dict: property = self._schema_properties[property_key] + expander = Expander in self._input_class.model_fields[property_key].metadata + + property["expander"] = expander + if not property.get("title"): # Set property key as fallback title property["title"] = _name_to_title(property_key) @@ -185,10 +192,25 @@ def render_ui(self) -> Dict: if attr is not None: property["instance_class"] = str(type(attr)) - try: - value = self._render_property(streamlit_app, property_key, property) + def _render_prop(): + value = self._render_property( + streamlit_app, property_key, property + ) if not self._is_value_ignored(property_key, value): self._store_value(property_key, value) + + try: + if expander and not schema_utils.is_object_list_property( + property, self._schema_references + ): + with self._streamlit_container.expander( + property_key, expanded=False + ): + _render_prop() + + else: + _render_prop() + except Exception: pass @@ -1012,7 +1034,9 @@ def _render_dict_clear_button( return data_dict - def _render_list_input(self, streamlit_app: Any, key: str, property: Dict) -> Any: + def _render_list_input( + self, streamlit_app: Any, key: str, property: Dict + ) -> Any: # Add title and subheader streamlit_app.subheader(property.get("title")) if property.get("description"): @@ -1041,20 +1065,35 @@ def _render_list_input(self, streamlit_app: Any, key: str, property: Dict) -> An if self._clear_button_allowed(property): data_list = self._render_list_clear_button(key, clear_col, data_list) + remove_inds = [] + + def _render_item(index, item): + output = self._render_list_item( + streamlit_app, + key, + item, + index, + property, + ) + if output is not None: + object_list.append(output) + + else: + remove_inds.append(index) + + if is_object: + streamlit_app.markdown("---") + if len(data_list) > 0: for index, item in enumerate(data_list): - output = self._render_list_item( - streamlit_app, - key, - item, - index, - property, - ) - if output is not None: - object_list.append(output) + if "expander" in property and property["expander"]: + with self._streamlit_container.expander( + property["title"], expanded=False + ): + _render_item(index, item) - if is_object: - streamlit_app.markdown("---") + else: + _render_item(index, item) if not self._add_button_allowed(len(object_list), property): add_col = add_col.empty() @@ -1062,6 +1101,12 @@ def _render_list_input(self, streamlit_app: Any, key: str, property: Dict) -> An if not is_object: streamlit_app.markdown("---") + for ind in reversed(remove_inds): + data_list.pop(ind) + + if len(remove_inds) > 0: + st.rerun() + return object_list def _render_property(self, streamlit_app: Any, key: str, property: Dict) -> Any: