Skip to content

Commit

Permalink
Added partially extracted slots support for the GroupSlots (#394)
Browse files Browse the repository at this point in the history
# Description

Added flag `allow_partial_extraction` to the `GroupSlot` class
constructor. If `True`, group slot only saves successfully extracted sub-slots.

---------

Co-authored-by: Roman Zlobin <[email protected]>
  • Loading branch information
NotBioWaste905 and RLKRo authored Nov 7, 2024
1 parent bd0c535 commit e5e286c
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 9 deletions.
30 changes: 21 additions & 9 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ def two_arg_getattr(__o, name):


def recursive_setattr(obj, slot_name: SlotName, value):
    """
    Set the attribute addressed by the dot-separated `slot_name` on `obj`.

    Everything before the last dot names the parent object, which is resolved
    with `recursive_getattr`; a name without dots refers to an attribute of
    `obj` itself.

    An `ExtractedGroupSlot` value is passed to the `update` method of the
    attribute that is already stored under that name instead of replacing it;
    any other value overwrites the attribute.

    :param obj: Object to start the lookup from.
    :param slot_name: Dot-separated path to the attribute.
    :param value: Value to store at that path.
    """
    parent_slot, sep, slot = slot_name.rpartition(".")

    if sep == ".":
        parent_obj = recursive_getattr(obj, parent_slot)
    else:
        parent_obj = obj

    if isinstance(value, ExtractedGroupSlot):
        # Group values update the stored group in place rather than overwrite it.
        getattr(parent_obj, slot).update(value)
    else:
        setattr(parent_obj, slot, value)


class SlotNotExtracted(Exception):
Expand Down Expand Up @@ -261,9 +266,11 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True):
"""

__pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]]
allow_partial_extraction: bool = False
"""If True, extraction returns only successfully extracted child slots."""

def __init__(self, allow_partial_extraction=False, **kwargs):
    # `allow_partial_extraction` is named explicitly in the signature; all
    # remaining keyword arguments are forwarded unchanged to the parent
    # initializer.
    super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs)

@model_validator(mode="after")
def __check_extra_field_names__(self):
Expand All @@ -279,9 +286,12 @@ def __check_extra_field_names__(self):

async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
    """
    Extract all child slots concurrently and collect the results.

    When `allow_partial_extraction` is set, only children whose extracted
    value reports `__slot_extracted__` are included in the result;
    otherwise every child value is included.

    :param ctx: Context passed to each child's `get_value`.
    :return: Group of the collected child values keyed by child name.
    """
    child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values()))
    # `gather` returns results in the order the children were passed in,
    # so zipping with the same mapping pairs each name with its value.
    extracted_values = {
        child_name: child_value
        for child_name, child_value in zip(self.__pydantic_extra__.keys(), child_values)
        if child_value.__slot_extracted__ or not self.allow_partial_extraction
    }

    return ExtractedGroupSlot(**extracted_values)

def init_value(self) -> ExtractedGroupSlot:
return ExtractedGroupSlot(
Expand Down Expand Up @@ -368,6 +378,8 @@ async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bo
"""
Extract slot `slot_name` and store extracted value in `slot_storage`.
Extracted group slots update slot storage instead of overwriting it.
:raises KeyError: If the slot with the specified name does not exist.
:param slot_name: Name of the slot to extract.
Expand Down
84 changes: 84 additions & 0 deletions tests/slots/test_slot_partial_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from chatsky.slots import RegexpSlot, GroupSlot
from chatsky.slots.slots import SlotManager
from chatsky.core import Message

import pytest

# Slot tree used by every test in this module: `root_slot` and
# `nested_partial_group` allow partial extraction, `nested_group` does not.
test_slot = GroupSlot(
    root_slot=GroupSlot(
        allow_partial_extraction=True,
        one=RegexpSlot(regexp=r"1"),
        two=RegexpSlot(regexp=r"2"),
        nested_group=GroupSlot(
            allow_partial_extraction=False,
            three=RegexpSlot(regexp=r"3"),
            four=RegexpSlot(regexp=r"4"),
        ),
        nested_partial_group=GroupSlot(
            allow_partial_extraction=True,
            five=RegexpSlot(regexp=r"5"),
            six=RegexpSlot(regexp=r"6"),
        ),
    )
)

# Full slot name -> value that slot holds once it is successfully extracted.
# Iteration order matters: `get_extracted_slots` reports values in this order.
extracted_slots = {
    "root_slot.one": "1",
    "root_slot.two": "2",
    "root_slot.nested_group.three": "3",
    "root_slot.nested_group.four": "4",
    "root_slot.nested_partial_group.five": "5",
    "root_slot.nested_partial_group.six": "6",
}


@pytest.fixture(scope="function")
def context_with_request(context):
    """Provide a callable that adds a request to `context` and returns it."""

    def add_request_to_context(request):
        # Wrap the raw request in a `Message` before adding it to the context.
        context.add_request(Message(request))
        return context

    return add_request_to_context


@pytest.fixture(scope="function")
def empty_slot_manager():
    """Provide a new `SlotManager` whose root slot is set to `test_slot`."""
    slot_manager = SlotManager()
    slot_manager.set_root_slot(test_slot)
    return slot_manager


def get_extracted_slots(manager: SlotManager):
    """
    Collect the values of all slots from `extracted_slots` that are
    marked as extracted in `manager`.

    :param manager: Slot manager to read extracted slots from.
    :return: Expected values of the extracted slots, in `extracted_slots` order.
    :raises RuntimeError: If an extracted slot holds an unexpected value.
    """
    found = []
    for slot_name, expected in extracted_slots.items():
        stored = manager.get_extracted_slot(slot_name)
        if not stored.__slot_extracted__:
            # Slots that were not extracted are simply left out of the result.
            continue
        if stored.value == expected:
            found.append(expected)
            continue
        raise RuntimeError(f"Extracted value {stored} does not match expected {expected}.")
    return found


@pytest.mark.parametrize(
    "message,extracted",
    [
        ("1 2 3", ["1", "2"]),
        ("1 3 5", ["1", "5"]),
        ("3 4 5 6", ["3", "4", "5", "6"]),
    ],
)
async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager):
    """Extract `root_slot` from `message` and compare the extracted values with `extracted`."""
    ctx = context_with_request(message)
    await empty_slot_manager.extract_slot("root_slot", ctx, success_only=False)

    assert get_extracted_slots(empty_slot_manager) == extracted


async def test_slot_storage_update(context_with_request, empty_slot_manager):
    """Run several extractions in a row and check the accumulated slot values after each one."""
    # (slot to extract, request text, values expected to be extracted afterwards)
    steps = (
        ("root_slot", "1 3 5", ["1", "5"]),
        ("root_slot", "2 4 6", ["1", "2", "5", "6"]),
        ("root_slot.nested_group", "3 4", ["1", "2", "3", "4", "5", "6"]),
    )

    for slot_name, message, expected in steps:
        await empty_slot_manager.extract_slot(slot_name, context_with_request(message), success_only=False)

        assert get_extracted_slots(empty_slot_manager) == expected
236 changes: 236 additions & 0 deletions tutorials/slots/2_partial_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# %% [markdown]
"""
# 2. Partial slot extraction
This tutorial shows advanced slot extraction options that make it possible
to extract only some of the slots.
"""

# %pip install chatsky

# %%
from chatsky import (
RESPONSE,
TRANSITIONS,
PRE_RESPONSE,
GLOBAL,
Pipeline,
Transition as Tr,
conditions as cnd,
processing as proc,
responses as rsp,
)

from chatsky.slots import RegexpSlot, GroupSlot

from chatsky.utils.testing import (
check_happy_path,
is_interactive_mode,
)

# %% [markdown]
"""
## Extracted values
Result of successful slot extraction is the extracted value, *but*
if the extraction fails, the slot will be marked as "not extracted"
and its value will be set to the slot's `default_value` (`None` by default).
If a group slot is being extracted, the extraction is considered successful
if and only if all child slots are successfully extracted.
## Success only extraction
The `Extract` function accepts a `success_only` flag. When it is set,
the extracted value is not saved unless the extraction is successful.
This means that unsuccessfully trying to extract a slot will not overwrite
its previously extracted value.
Note that `success_only` is `True` by default.
## Partial group slot extraction
A group slot marked with `allow_partial_extraction` only saves values
of successfully extracted child slots.
Extracting such a group slot is equivalent to extracting every child slot
with the `success_only` flag.
A partially extracted group slot is always considered successfully extracted
for the purposes of the `success_only` flag.
## Code explanation
In this example we showcase the behavior of
different group slot extraction settings:
Group `partial_extraction` is marked with `allow_partial_extraction`.
Any slot in this group is saved if and only if that slot is successfully
extracted.
Group `success_only_extraction` is extracted with the `success_only`
flag set to True.
Any slot in this group is saved if and only if all of the slots in the group
are successfully extracted within a single `Extract` call.
Group `success_only_false` is extracted with the `success_only` flag set to False.
Every slot in this group is saved (even if extraction was not successful).
Group `sub_slot_success_only_extraction` is extracted by passing all of its
child slots to the `Extract` method with the `success_only` flag set to True.
The behavior is equivalent to that of `partial_extraction`.
"""

# %%
# Child slots shared by every group below.
sub_slots = {
    "date": RegexpSlot(
        regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]"
        r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})",
    ),
    "email": RegexpSlot(regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}"),
}

# All groups contain the same child slots; only `partial_extraction`
# is created with `allow_partial_extraction` enabled.
SLOTS = {
    "partial_extraction": GroupSlot(
        allow_partial_extraction=True,
        **sub_slots,
    ),
    "success_only_extraction": GroupSlot(**sub_slots),
    "success_only_false": GroupSlot(**sub_slots),
    "sub_slot_success_only_extraction": GroupSlot(**sub_slots),
}

script = {
    GLOBAL: {
        TRANSITIONS: [
            Tr(dst=("main", "start"), cnd=cnd.ExactMatch("/start")),
            Tr(dst=("main", "reset"), cnd=cnd.ExactMatch("/reset")),
            Tr(dst=("main", "print"), priority=0.5),
        ]
    },
    "main": {
        "start": {RESPONSE: "Hi! Send me email and date."},
        "reset": {
            PRE_RESPONSE: {"reset_slots": proc.UnsetAll()},
            RESPONSE: "All slots have been reset.",
        },
        "print": {
            PRE_RESPONSE: {
                # partial extraction is always successful;
                # success_only doesn't matter
                "partial_extraction": proc.Extract("partial_extraction"),
                # success_only is True by default
                "success_only_extraction": proc.Extract(
                    "success_only_extraction", success_only=True
                ),
                # unsuccessfully extracted values are saved as well
                "success_only_false": proc.Extract(
                    "success_only_false", success_only=False
                ),
                # each child slot is passed to `Extract` individually
                "sub_slot_success_only_extraction": proc.Extract(
                    "sub_slot_success_only_extraction.email",
                    "sub_slot_success_only_extraction.date",
                    success_only=True,
                ),
            },
            RESPONSE: rsp.FilledTemplate(
                "Extracted slots:\n"
                " Group with partial extraction:\n"
                " {partial_extraction}\n"
                " Group with success_only:\n"
                " {success_only_extraction}\n"
                " Group without success_only:\n"
                " {success_only_false}\n"
                " Extracting sub-slots with success_only:\n"
                " {sub_slot_success_only_extraction}"
            ),
        },
    },
}

# Each item is a (user request, expected bot response) pair.
# NOTE(review): the e-mail addresses were replaced by an e-mail protection
# placeholder in the scraped page; the two addresses below are
# stand-ins that the tutorial's e-mail regexp matches — confirm against
# the repository.
HAPPY_PATH = [
    ("/start", "Hi! Send me email and date."),
    (
        "Only email: email@email.com",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': 'None', 'email': 'email@email.com'}\n"
        " Group with success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Group without success_only:\n"
        " {'date': 'None', 'email': 'email@email.com'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': 'None', 'email': 'email@email.com'}",
    ),
    (
        "Only date: 01.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '01.01.2024', 'email': 'email@email.com'}\n"
        " Group with success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Group without success_only:\n"
        " {'date': '01.01.2024', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '01.01.2024', 'email': 'email@email.com'}",
    ),
    (
        "Both email and date: another_email@email.com; 02.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n"
        " Group without success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}",
    ),
    (
        "Partial update (date only): 03.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n"
        " Group without success_only:\n"
        " {'date': '03.01.2024', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '03.01.2024', 'email': 'another_email@email.com'}",
    ),
    (
        "No slots here but `Extract` will still be called.",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n"
        " Group without success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '03.01.2024', 'email': 'another_email@email.com'}",
    ),
]


# %%
# Build the pipeline from the script above and register the slot groups.
pipeline = Pipeline(script=script, start_label=("main", "start"), slots=SLOTS)


# %%
if __name__ == "__main__":
    # Replay the happy path against the pipeline, printing the dialog.
    check_happy_path(pipeline, HAPPY_PATH, printout=True)

    # Start the pipeline only when running in interactive mode.
    if is_interactive_mode():
        pipeline.run()

0 comments on commit e5e286c

Please sign in to comment.