From e5e286c57bd5ccc25528629f22ffc240b9ca0b45 Mon Sep 17 00:00:00 2001 From: NotBioWaste905 <59259188+NotBioWaste905@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:55:11 +0300 Subject: [PATCH] Added partially extracted slots support for the GroupSlots (#394) # Description Added flag `allow_partial_extraction` to the `GroupSlot` class constructor. If `True`, group slot only saves successfully extracted sub-slots. --------- Co-authored-by: Roman Zlobin --- chatsky/slots/slots.py | 30 ++- tests/slots/test_slot_partial_extraction.py | 84 +++++++ tutorials/slots/2_partial_extraction.py | 236 ++++++++++++++++++++ 3 files changed, 341 insertions(+), 9 deletions(-) create mode 100644 tests/slots/test_slot_partial_extraction.py create mode 100644 tutorials/slots/2_partial_extraction.py diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 276a28f56..3cadd9205 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -69,12 +69,17 @@ def two_arg_getattr(__o, name): def recursive_setattr(obj, slot_name: SlotName, value): - parent_slot, _, slot = slot_name.rpartition(".") + parent_slot, sep, slot = slot_name.rpartition(".") - if parent_slot: - setattr(recursive_getattr(obj, parent_slot), slot, value) + if sep == ".": + parent_obj = recursive_getattr(obj, parent_slot) else: - setattr(obj, slot, value) + parent_obj = obj + + if isinstance(value, ExtractedGroupSlot): + getattr(parent_obj, slot).update(value) + else: + setattr(parent_obj, slot, value) class SlotNotExtracted(Exception): @@ -261,9 +266,11 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True): """ __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + allow_partial_extraction: bool = False + """If True, extraction returns only successfully extracted child slots.""" - def __init__(self, **kwargs): # supress unexpected argument warnings - super().__init__(**kwargs) + def __init__(self, allow_partial_extraction=False, **kwargs): + super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs) @model_validator(mode="after") def __check_extra_field_names__(self): @@ -279,9 +286,12 @@ def __check_extra_field_names__(self): async def get_value(self, ctx: Context) -> ExtractedGroupSlot: child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) - return ExtractedGroupSlot( - **{child_name: child_value for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys())} - ) + extracted_values = {} + for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()): + if child_value.__slot_extracted__ or not self.allow_partial_extraction: + extracted_values[child_name] = child_value + + return ExtractedGroupSlot(**extracted_values) def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( @@ -368,6 +378,8 @@ async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bo """ Extract slot `slot_name` and store extracted value in `slot_storage`. + Extracted group slots update slot storage instead of overwriting it. + :raises KeyError: If the slot with the specified name does not exist. :param slot_name: Name of the slot to extract. diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py new file mode 100644 index 000000000..234287c17 --- /dev/null +++ b/tests/slots/test_slot_partial_extraction.py @@ -0,0 +1,84 @@ +from chatsky.slots import RegexpSlot, GroupSlot +from chatsky.slots.slots import SlotManager +from chatsky.core import Message + +import pytest + +test_slot = GroupSlot( + root_slot=GroupSlot( + one=RegexpSlot(regexp=r"1"), + two=RegexpSlot(regexp=r"2"), + nested_group=GroupSlot( + three=RegexpSlot(regexp=r"3"), + four=RegexpSlot(regexp=r"4"), + allow_partial_extraction=False, + ), + nested_partial_group=GroupSlot( + five=RegexpSlot(regexp=r"5"), + six=RegexpSlot(regexp=r"6"), + allow_partial_extraction=True, + ), + allow_partial_extraction=True, + ) +) + +extracted_slots = { + "root_slot.one": "1", + "root_slot.two": "2", + "root_slot.nested_group.three": "3", + "root_slot.nested_group.four": "4", + "root_slot.nested_partial_group.five": "5", + "root_slot.nested_partial_group.six": "6", +} + + +@pytest.fixture(scope="function") +def context_with_request(context): + def inner(request): + context.add_request(Message(request)) + return context + + return inner + + +@pytest.fixture(scope="function") +def empty_slot_manager(): + manager = SlotManager() + manager.set_root_slot(test_slot) + return manager + + +def get_extracted_slots(manager: SlotManager): + values = [] + for slot, value in extracted_slots.items(): + extracted_value = manager.get_extracted_slot(slot) + if extracted_value.__slot_extracted__: + if extracted_value.value == value: + values.append(value) + else: + raise RuntimeError(f"Extracted value {extracted_value} does not match expected {value}.") + return values + + +@pytest.mark.parametrize( + "message,extracted", + [("1 2 3", ["1", "2"]), ("1 3 5", ["1", "5"]), ("3 4 5 6", ["3", "4", "5", "6"])], +) +async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager): + await empty_slot_manager.extract_slot("root_slot", context_with_request(message), success_only=False) + + assert extracted == get_extracted_slots(empty_slot_manager) + + +async def test_slot_storage_update(context_with_request, empty_slot_manager): + await empty_slot_manager.extract_slot("root_slot", context_with_request("1 3 5"), success_only=False) + + assert get_extracted_slots(empty_slot_manager) == ["1", "5"] + + await empty_slot_manager.extract_slot("root_slot", context_with_request("2 4 6"), success_only=False) + + assert get_extracted_slots(empty_slot_manager) == ["1", "2", "5", "6"] + + await empty_slot_manager.extract_slot("root_slot.nested_group", context_with_request("3 4"), success_only=False) + + assert get_extracted_slots(empty_slot_manager) == ["1", "2", "3", "4", "5", "6"] diff --git a/tutorials/slots/2_partial_extraction.py b/tutorials/slots/2_partial_extraction.py new file mode 100644 index 000000000..87cc4bab4 --- /dev/null +++ b/tutorials/slots/2_partial_extraction.py @@ -0,0 +1,236 @@ +# %% [markdown] +""" +# 2. Partial slot extraction + +This tutorial shows advanced options for slot extraction allowing +to extract only some of the slots. +""" + +# %pip install chatsky + +# %% +from chatsky import ( + RESPONSE, + TRANSITIONS, + PRE_RESPONSE, + GLOBAL, + Pipeline, + Transition as Tr, + conditions as cnd, + processing as proc, + responses as rsp, +) + +from chatsky.slots import RegexpSlot, GroupSlot + +from chatsky.utils.testing import ( + check_happy_path, + is_interactive_mode, +) + +# %% [markdown] +""" +## Extracted values + +Result of successful slot extraction is the extracted value, *but* +if the extraction fails, the slot will be marked as "not extracted" +and its value will be set to the slot's `default_value` (`None` by default). + +If group slot is being extracted, the extraction is considered successful +if and only if all child slots are successfully extracted. + +## Success only extraction + +The `Extract` function accepts `success_only` flag which makes it so +that extracted value is not saved unless extraction is successful. + +This means that unsuccessfully trying to extract a slot will not overwrite +its previously extracted value. + +Note that `success_only` is `True` by default. + +## Partial group slot extraction + +A group slot marked with `allow_partial_extraction` only saves values +of successfully extracted child slots. +Extracting such group slot is equivalent to extracting every child slot +with the `success_only` flag. + +Partially extracted group slot is always considered successfully extracted +for the purposes of the `success_only` flag. + +## Code explanation + +In this example we showcase the behavior of +different group slot extraction settings: + +Group `partial_extraction` is marked with `allow_partial_extraction`. +Any slot in this group is saved if and only if that slot is successfully +extracted. + +Group `success_only_extraction` is extracted with the `success_only` +flag set to True. +Any slot in this group is saved if and only if all of the slots in the group +are successfully extracted within a single `Extract` call. + +Group `success_only_false` is extracted with the `success_only` set to False. +Every slot in this group is saved (even if extraction was not successful). + +Group `sub_slot_success_only_extraction` is extracted by passing all of its +child slots to the `Extract` method with the `success_only` flag set to True. +The behavior is equivalent to that of `partial_extraction`. +""" + +# %% +sub_slots = { + "date": RegexpSlot( + regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]" + r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})", + ), + "email": RegexpSlot( + regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}", + ), +} + +SLOTS = { + "partial_extraction": GroupSlot( + **sub_slots, + allow_partial_extraction=True, + ), + "success_only_extraction": GroupSlot( + **sub_slots, + ), + "success_only_false": GroupSlot( + **sub_slots, + ), + "sub_slot_success_only_extraction": GroupSlot( + **sub_slots, + ), +} + +script = { + GLOBAL: { + TRANSITIONS: [ + Tr(dst=("main", "start"), cnd=cnd.ExactMatch("/start")), + Tr(dst=("main", "reset"), cnd=cnd.ExactMatch("/reset")), + Tr(dst=("main", "print"), priority=0.5), + ] + }, + "main": { + "start": {RESPONSE: "Hi! Send me email and date."}, + "reset": { + PRE_RESPONSE: {"reset_slots": proc.UnsetAll()}, + RESPONSE: "All slots have been reset.", + }, + "print": { + PRE_RESPONSE: { + "partial_extraction": proc.Extract("partial_extraction"), + # partial extraction is always successful; + # success_only doesn't matter + "success_only_extraction": proc.Extract( + "success_only_extraction", success_only=True + ), + # success_only is True by default + "success_only_false": proc.Extract( + "success_only_false", success_only=False + ), + "sub_slot_success_only_extraction": proc.Extract( + "sub_slot_success_only_extraction.email", + "sub_slot_success_only_extraction.date", + success_only=True, + ), + }, + RESPONSE: rsp.FilledTemplate( + "Extracted slots:\n" + " Group with partial extraction:\n" + " {partial_extraction}\n" + " Group with success_only:\n" + " {success_only_extraction}\n" + " Group without success_only:\n" + " {success_only_false}\n" + " Extracting sub-slots with success_only:\n" + " {sub_slot_success_only_extraction}" + ), + }, + }, +} + +HAPPY_PATH = [ + ("/start", "Hi! Send me email and date."), + ( + "Only email: email@email.com", + "Extracted slots:\n" + " Group with partial extraction:\n" + " {'date': 'None', 'email': 'email@email.com'}\n" + " Group with success_only:\n" + " {'date': 'None', 'email': 'None'}\n" + " Group without success_only:\n" + " {'date': 'None', 'email': 'email@email.com'}\n" + " Extracting sub-slots with success_only:\n" + " {'date': 'None', 'email': 'email@email.com'}", + ), + ( + "Only date: 01.01.2024", + "Extracted slots:\n" + " Group with partial extraction:\n" + " {'date': '01.01.2024', 'email': 'email@email.com'}\n" + " Group with success_only:\n" + " {'date': 'None', 'email': 'None'}\n" + " Group without success_only:\n" + " {'date': '01.01.2024', 'email': 'None'}\n" + " Extracting sub-slots with success_only:\n" + " {'date': '01.01.2024', 'email': 'email@email.com'}", + ), + ( + "Both email and date: another_email@email.com; 02.01.2024", + "Extracted slots:\n" + " Group with partial extraction:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" + " Group with success_only:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" + " Group without success_only:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" + " Extracting sub-slots with success_only:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}", + ), + ( + "Partial update (date only): 03.01.2024", + "Extracted slots:\n" + " Group with partial extraction:\n" + " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" + " Group with success_only:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" + " Group without success_only:\n" + " {'date': '03.01.2024', 'email': 'None'}\n" + " Extracting sub-slots with success_only:\n" + " {'date': '03.01.2024', 'email': 'another_email@email.com'}", + ), + ( + "No slots here but `Extract` will still be called.", + "Extracted slots:\n" + " Group with partial extraction:\n" + " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" + " Group with success_only:\n" + " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" + " Group without success_only:\n" + " {'date': 'None', 'email': 'None'}\n" + " Extracting sub-slots with success_only:\n" + " {'date': '03.01.2024', 'email': 'another_email@email.com'}", + ), +] + + +# %% +pipeline = Pipeline( + script=script, + start_label=("main", "start"), + slots=SLOTS, +) + + +# %% +if __name__ == "__main__": + check_happy_path(pipeline, HAPPY_PATH, printout=True) + + if is_interactive_mode(): + pipeline.run()