Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added partially extracted slots support for the GroupSlots #394

Merged
merged 25 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f1857f6
Added flag to
NotBioWaste905 Oct 1, 2024
c334ff5
First test attempts
NotBioWaste905 Oct 1, 2024
8306bbb
linting
NotBioWaste905 Oct 1, 2024
33f05d0
Added groupslot tutorial for slots
NotBioWaste905 Oct 7, 2024
09937ae
Switched to unit tests
NotBioWaste905 Oct 9, 2024
218e8e9
Lint
NotBioWaste905 Oct 9, 2024
f217832
simplify recursive_setattr
RLKRo Oct 14, 2024
da48f68
update docstrings
RLKRo Oct 14, 2024
a09037c
remove unrelated llm tests
RLKRo Oct 14, 2024
e534f4c
rewrite partial extraction tests
RLKRo Oct 14, 2024
66f3db0
rename allow_partially_extracted to allow_partial_extraction
RLKRo Oct 15, 2024
c489024
add check_happy_path block to tutorial
RLKRo Oct 15, 2024
e6a9468
reformat tutorial
RLKRo Oct 15, 2024
78f3b2b
rewrite tutorial text
RLKRo Oct 15, 2024
4f59a35
Merge branch 'refs/heads/dev' into feat/slots_extraction_update
RLKRo Oct 15, 2024
37b0218
Updated happy path, fixed the script
NotBioWaste905 Oct 22, 2024
7915188
minor text changes
RLKRo Oct 23, 2024
a06ea3c
fix codestyle
RLKRo Oct 23, 2024
204c4e2
Working on tutorial
NotBioWaste905 Oct 30, 2024
8869eef
Added GroupSlotsExtracted class with required field
NotBioWaste905 Oct 31, 2024
376bebd
lint
NotBioWaste905 Oct 31, 2024
713203c
Removed `GroupSlotsExtracted`, updated tutorial
NotBioWaste905 Nov 2, 2024
08dab5d
update tutorial: fix wording
RLKRo Nov 6, 2024
332fa34
update tutorial: change script to showcase behavior with different se…
RLKRo Nov 6, 2024
b31936a
update tutorial: fix wording
RLKRo Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ def two_arg_getattr(__o, name):


def recursive_setattr(obj, slot_name: SlotName, value):
parent_slot, _, slot = slot_name.rpartition(".")
parent_slot, sep, slot = slot_name.rpartition(".")

if parent_slot:
setattr(recursive_getattr(obj, parent_slot), slot, value)
if sep == ".":
parent_obj = recursive_getattr(obj, parent_slot)
else:
setattr(obj, slot, value)
parent_obj = obj

if isinstance(value, ExtractedGroupSlot):
getattr(parent_obj, slot).update(value)
else:
setattr(parent_obj, slot, value)


class SlotNotExtracted(Exception):
Expand Down Expand Up @@ -261,9 +266,11 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True):
"""

__pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]]
allow_partial_extraction: bool = False
"""If True, extraction returns only successfully extracted child slots."""

def __init__(self, **kwargs): # supress unexpected argument warnings
super().__init__(**kwargs)
def __init__(self, allow_partial_extraction=False, **kwargs):
super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs)

@model_validator(mode="after")
def __check_extra_field_names__(self):
Expand All @@ -279,9 +286,12 @@ def __check_extra_field_names__(self):

async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values()))
return ExtractedGroupSlot(
**{child_name: child_value for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys())}
)
extracted_values = {}
for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()):
if child_value.__slot_extracted__ or not self.allow_partial_extraction:
extracted_values[child_name] = child_value

return ExtractedGroupSlot(**extracted_values)

def init_value(self) -> ExtractedGroupSlot:
return ExtractedGroupSlot(
Expand Down Expand Up @@ -368,6 +378,8 @@ async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bo
"""
Extract slot `slot_name` and store extracted value in `slot_storage`.

Extracted group slots update slot storage instead of overwriting it.

:raises KeyError: If the slot with the specified name does not exist.

:param slot_name: Name of the slot to extract.
Expand Down
84 changes: 84 additions & 0 deletions tests/slots/test_slot_partial_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from chatsky.slots import RegexpSlot, GroupSlot
from chatsky.slots.slots import SlotManager
from chatsky.core import Message

import pytest

# Slot tree shared by every test in this module.
# The outer group holds a single `root_slot` group. `root_slot` and
# `nested_partial_group` are created with `allow_partial_extraction=True`,
# while `nested_group` explicitly sets it to False.
# Each value slot is a `RegexpSlot` matching one digit ("1" .. "6").
test_slot = GroupSlot(
    root_slot=GroupSlot(
        one=RegexpSlot(regexp=r"1"),
        two=RegexpSlot(regexp=r"2"),
        nested_group=GroupSlot(
            three=RegexpSlot(regexp=r"3"),
            four=RegexpSlot(regexp=r"4"),
            allow_partial_extraction=False,
        ),
        nested_partial_group=GroupSlot(
            five=RegexpSlot(regexp=r"5"),
            six=RegexpSlot(regexp=r"6"),
            allow_partial_extraction=True,
        ),
        allow_partial_extraction=True,
    )
)

# Maps the dotted name of every value slot in `test_slot` to the value
# the tests expect that slot to hold once it has been extracted.
extracted_slots = {
    "root_slot.one": "1",
    "root_slot.two": "2",
    "root_slot.nested_group.three": "3",
    "root_slot.nested_group.four": "4",
    "root_slot.nested_partial_group.five": "5",
    "root_slot.nested_partial_group.six": "6",
}


@pytest.fixture(scope="function")
def context_with_request(context):
    """
    Provide a factory that appends a request to the `context` fixture.

    The returned callable wraps its argument in a `Message`, adds it to
    the context as a new request and returns that same context.
    """

    def add_request(text):
        context.add_request(Message(text))
        return context

    return add_request


@pytest.fixture(scope="function")
def empty_slot_manager():
    """Provide a fresh `SlotManager` whose root slot is `test_slot`."""
    slot_manager = SlotManager()
    slot_manager.set_root_slot(test_slot)
    return slot_manager


def get_extracted_slots(manager: SlotManager):
    """
    Collect the expected values of all slots that `manager` has extracted.

    Slots are visited in the order of `extracted_slots`; slots that are not
    marked as extracted are skipped.

    :param manager: Slot manager to read the extracted slots from.
    :return: Expected values of the extracted slots.
    :raises RuntimeError: If an extracted slot holds a value different
        from the expected one.
    """
    matched = []
    for slot_name, expected in extracted_slots.items():
        extracted = manager.get_extracted_slot(slot_name)
        if not extracted.__slot_extracted__:
            continue
        if extracted.value == expected:
            matched.append(expected)
            continue
        raise RuntimeError(f"Extracted value {extracted} does not match expected {expected}.")
    return matched


@pytest.mark.parametrize(
    "message,extracted",
    [
        ("1 2 3", ["1", "2"]),
        ("1 3 5", ["1", "5"]),
        ("3 4 5 6", ["3", "4", "5", "6"]),
    ],
)
async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager):
    """Extract `root_slot` from a single message and compare the stored values."""
    ctx = context_with_request(message)
    await empty_slot_manager.extract_slot("root_slot", ctx, success_only=False)

    stored = get_extracted_slots(empty_slot_manager)
    assert extracted == stored


async def test_slot_storage_update(context_with_request, empty_slot_manager):
    """Run consecutive extractions and check the stored values after each one."""
    # (slot to extract, request text, values expected in storage afterwards)
    steps = [
        ("root_slot", "1 3 5", ["1", "5"]),
        ("root_slot", "2 4 6", ["1", "2", "5", "6"]),
        ("root_slot.nested_group", "3 4", ["1", "2", "3", "4", "5", "6"]),
    ]

    for slot_name, text, expected in steps:
        await empty_slot_manager.extract_slot(slot_name, context_with_request(text), success_only=False)

        assert get_extracted_slots(empty_slot_manager) == expected
236 changes: 236 additions & 0 deletions tutorials/slots/2_partial_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# %% [markdown]
"""
# 2. Partial slot extraction

This tutorial shows advanced slot extraction options that make it possible
to extract only some of the slots.
"""

# %pip install chatsky

# %%
from chatsky import (
RESPONSE,
TRANSITIONS,
PRE_RESPONSE,
GLOBAL,
Pipeline,
Transition as Tr,
conditions as cnd,
processing as proc,
responses as rsp,
)

from chatsky.slots import RegexpSlot, GroupSlot

from chatsky.utils.testing import (
check_happy_path,
is_interactive_mode,
)

# %% [markdown]
"""
## Extracted values

The result of a successful slot extraction is the extracted value, *but*
if the extraction fails, the slot is marked as "not extracted"
and its value is set to the slot's `default_value` (`None` by default).

When a group slot is extracted, the extraction is considered successful
if and only if all of its child slots are successfully extracted.

## Success only extraction

The `Extract` function accepts a `success_only` flag. When the flag is set,
the extracted value is saved only if the extraction is successful.

This means that unsuccessfully trying to extract a slot will not overwrite
its previously extracted value.

Note that `success_only` is `True` by default.

## Partial group slot extraction

A group slot marked with `allow_partial_extraction` only saves the values
of successfully extracted child slots.
Extracting such a group slot is equivalent to extracting every child slot
with the `success_only` flag.

A partially extracted group slot is always considered successfully extracted
for the purposes of the `success_only` flag.

## Code explanation

In this example we showcase the behavior of
different group slot extraction settings:

Group `partial_extraction` is marked with `allow_partial_extraction`.
Any slot in this group is saved if and only if that slot is successfully
extracted.

Group `success_only_extraction` is extracted with the `success_only`
flag set to True.
Any slot in this group is saved if and only if all of the slots in the group
are successfully extracted within a single `Extract` call.

Group `success_only_false` is extracted with the `success_only` flag
set to False.
Every slot in this group is saved (even if extraction was not successful).

Group `sub_slot_success_only_extraction` is extracted by passing all of its
child slots to the `Extract` method with the `success_only` flag set to True.
The behavior is equivalent to that of `partial_extraction`.
"""

# %%
# Child slots that are unpacked into every group slot declared in `SLOTS`.
sub_slots = {
    # three groups separated by "." or "/": 1-31, 1-12 and a 4- or 2-digit
    # number — presumably day, month and year
    "date": RegexpSlot(
        regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]"
        r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})",
    ),
    "email": RegexpSlot(
        regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}",
    ),
}

# All four groups share the same child slots (`sub_slots`).
# Only `partial_extraction` sets `allow_partial_extraction`; the other three
# are declared identically and differ only in how the script extracts them.
SLOTS = {
    "partial_extraction": GroupSlot(
        **sub_slots,
        allow_partial_extraction=True,
    ),
    "success_only_extraction": GroupSlot(
        **sub_slots,
    ),
    "success_only_false": GroupSlot(
        **sub_slots,
    ),
    "sub_slot_success_only_extraction": GroupSlot(
        **sub_slots,
    ),
}

# Dialog script: global transitions plus a single "main" flow with the
# "start", "reset" and "print" nodes.
script = {
    GLOBAL: {
        TRANSITIONS: [
            # "/start" and "/reset" are matched exactly
            Tr(dst=("main", "start"), cnd=cnd.ExactMatch("/start")),
            Tr(dst=("main", "reset"), cnd=cnd.ExactMatch("/reset")),
            # no condition is given here; presumably this is the fallback
            # for any other request — confirm how `priority=0.5` is ranked
            Tr(dst=("main", "print"), priority=0.5),
        ]
    },
    "main": {
        "start": {RESPONSE: "Hi! Send me email and date."},
        "reset": {
            PRE_RESPONSE: {"reset_slots": proc.UnsetAll()},
            RESPONSE: "All slots have been reset.",
        },
        "print": {
            # every group is extracted on each visit, each in its own way
            PRE_RESPONSE: {
                # this group is declared with `allow_partial_extraction`:
                # partial extraction is always successful;
                # success_only doesn't matter
                "partial_extraction": proc.Extract("partial_extraction"),
                # success_only is True by default; spelled out for clarity
                "success_only_extraction": proc.Extract(
                    "success_only_extraction", success_only=True
                ),
                "success_only_false": proc.Extract(
                    "success_only_false", success_only=False
                ),
                # child slots are passed to `Extract` one by one
                "sub_slot_success_only_extraction": proc.Extract(
                    "sub_slot_success_only_extraction.email",
                    "sub_slot_success_only_extraction.date",
                    success_only=True,
                ),
            },
            # placeholders are the names of the group slots in `SLOTS`
            RESPONSE: rsp.FilledTemplate(
                "Extracted slots:\n"
                " Group with partial extraction:\n"
                " {partial_extraction}\n"
                " Group with success_only:\n"
                " {success_only_extraction}\n"
                " Group without success_only:\n"
                " {success_only_false}\n"
                " Extracting sub-slots with success_only:\n"
                " {sub_slot_success_only_extraction}"
            ),
        },
    },
}

# Scripted dialog: (request, expected response) pairs replayed by
# `check_happy_path`.
#
# NOTE(review): in this copy every email address had been replaced by the
# placeholder "[email protected]", which contains no "@" and therefore can
# never be matched by the `email` regexp slot. The `example.com` addresses
# below are stand-ins that make the path consistent again — confirm them
# against the upstream tutorial.
HAPPY_PATH = [
    ("/start", "Hi! Send me email and date."),
    (
        # only the email is present: the `success_only` group saves nothing
        "Only email: email@example.com",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': 'None', 'email': 'email@example.com'}\n"
        " Group with success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Group without success_only:\n"
        " {'date': 'None', 'email': 'email@example.com'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': 'None', 'email': 'email@example.com'}",
    ),
    (
        # only the date is present: the group without `success_only`
        # overwrites the previously stored email with 'None'
        "Only date: 01.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '01.01.2024', 'email': 'email@example.com'}\n"
        " Group with success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Group without success_only:\n"
        " {'date': '01.01.2024', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '01.01.2024', 'email': 'email@example.com'}",
    ),
    (
        # both slots are present: every group stores both values
        "Both email and date: another_email@example.com; 02.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}\n"
        " Group without success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}",
    ),
    (
        # date only: the `success_only` group keeps its previous values
        "Partial update (date only): 03.01.2024",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '03.01.2024', 'email': 'another_email@example.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}\n"
        " Group without success_only:\n"
        " {'date': '03.01.2024', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '03.01.2024', 'email': 'another_email@example.com'}",
    ),
    (
        # nothing to extract: only the group without `success_only` changes
        "No slots here but `Extract` will still be called.",
        "Extracted slots:\n"
        " Group with partial extraction:\n"
        " {'date': '03.01.2024', 'email': 'another_email@example.com'}\n"
        " Group with success_only:\n"
        " {'date': '02.01.2024', 'email': 'another_email@example.com'}\n"
        " Group without success_only:\n"
        " {'date': 'None', 'email': 'None'}\n"
        " Extracting sub-slots with success_only:\n"
        " {'date': '03.01.2024', 'email': 'another_email@example.com'}",
    ),
]


# %%
# Pipeline built from the script above; the dialog starts at ("main", "start")
# and `SLOTS` registers the group slots that the script extracts.
pipeline = Pipeline(
    script=script,
    start_label=("main", "start"),
    slots=SLOTS,
)


# %%
if __name__ == "__main__":
    # replay the scripted dialog from `HAPPY_PATH` against the pipeline
    check_happy_path(pipeline, HAPPY_PATH, printout=True)

    # start a live session only when running interactively
    if is_interactive_mode():
        pipeline.run()
Loading