Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added partially extracted slots support for the GroupSlots #394

Merged
merged 25 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f1857f6
Added flag to
NotBioWaste905 Oct 1, 2024
c334ff5
First test attempts
NotBioWaste905 Oct 1, 2024
8306bbb
linting
NotBioWaste905 Oct 1, 2024
33f05d0
Added groupslot tutorial for slots
NotBioWaste905 Oct 7, 2024
09937ae
Switched to unit tests
NotBioWaste905 Oct 9, 2024
218e8e9
Lint
NotBioWaste905 Oct 9, 2024
f217832
simplify recursive_setattr
RLKRo Oct 14, 2024
da48f68
update docstrings
RLKRo Oct 14, 2024
a09037c
remove unrelated llm tests
RLKRo Oct 14, 2024
e534f4c
rewrite partial extraction tests
RLKRo Oct 14, 2024
66f3db0
rename allow_partially_extracted to allow_partial_extraction
RLKRo Oct 15, 2024
c489024
add check_happy_path block to tutorial
RLKRo Oct 15, 2024
e6a9468
reformat tutorial
RLKRo Oct 15, 2024
78f3b2b
rewrite tutorial text
RLKRo Oct 15, 2024
4f59a35
Merge branch 'refs/heads/dev' into feat/slots_extraction_update
RLKRo Oct 15, 2024
37b0218
Updated happy path, fixed the script
NotBioWaste905 Oct 22, 2024
7915188
minor text changes
RLKRo Oct 23, 2024
a06ea3c
fix codestyle
RLKRo Oct 23, 2024
204c4e2
Working on tutorial
NotBioWaste905 Oct 30, 2024
8869eef
Added GroupSlotsExtracted class with required field
NotBioWaste905 Oct 31, 2024
376bebd
lint
NotBioWaste905 Oct 31, 2024
713203c
Removed `GroupSlotsExtracted`, updated tutorial
NotBioWaste905 Nov 2, 2024
08dab5d
update tutorial: fix wording
RLKRo Nov 6, 2024
332fa34
update tutorial: change script to showcase behavior with different se…
RLKRo Nov 6, 2024
b31936a
update tutorial: fix wording
RLKRo Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions chatsky/slots/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ def two_arg_getattr(__o, name):


def recursive_setattr(obj, slot_name: SlotName, value):
    """
    Store ``value`` under the dot-separated ``slot_name`` path inside ``obj``.

    The path is split at its last dot: the part before the dot names the
    parent object (resolved with ``recursive_getattr``) and the part after it
    names the attribute on that parent. A name without a dot refers to an
    attribute of ``obj`` itself.

    An ``ExtractedGroupSlot`` value is passed to the ``update`` method of the
    attribute that is already stored, so existing storage is updated rather
    than overwritten. Any other value is assigned with ``setattr``.

    :param obj: Object that holds the slot tree.
    :param slot_name: Dot-separated name of the slot to set.
    :param value: Value to store.
    """
    parent_slot, sep, slot = slot_name.rpartition(".")

    # ``rpartition`` returns an empty separator when the name contains no dot.
    parent_obj = recursive_getattr(obj, parent_slot) if sep == "." else obj

    if isinstance(value, ExtractedGroupSlot):
        getattr(parent_obj, slot).update(value)
    else:
        setattr(parent_obj, slot, value)


class SlotNotExtracted(Exception):
Expand Down Expand Up @@ -261,9 +266,11 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True):
"""

__pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]]
allow_partial_extraction: bool = False
"""If True, extraction returns only successfully extracted child slots."""

def __init__(self, allow_partial_extraction=False, **kwargs):
    """
    Initialize the group slot.

    :param allow_partial_extraction: Value for the field of the same name;
        forwarded to the parent initializer as a keyword argument.
    :param kwargs: Remaining keyword arguments, forwarded unchanged
        (presumably the child slots stored as extra fields -- confirm).
    """
    super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs)

@model_validator(mode="after")
def __check_extra_field_names__(self):
Expand All @@ -279,9 +286,12 @@ def __check_extra_field_names__(self):

async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
    """
    Extract every child slot concurrently and collect the results.

    When ``allow_partial_extraction`` is set, children whose extracted value
    is not marked with ``__slot_extracted__`` are left out of the result;
    otherwise every child value is included.

    :param ctx: Context passed to each child's ``get_value``.
    :return: Extracted group slot built from the selected child values.
    """
    child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values()))

    # ``gather`` preserves argument order, so values line up with the keys.
    extracted_values = {
        child_name: child_value
        for child_name, child_value in zip(self.__pydantic_extra__, child_values)
        if child_value.__slot_extracted__ or not self.allow_partial_extraction
    }

    return ExtractedGroupSlot(**extracted_values)

def init_value(self) -> ExtractedGroupSlot:
return ExtractedGroupSlot(
Expand Down Expand Up @@ -368,6 +378,8 @@ async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bo
"""
Extract slot `slot_name` and store extracted value in `slot_storage`.

Extracted group slots update slot storage instead of overwriting it.

:raises KeyError: If the slot with the specified name does not exist.

:param slot_name: Name of the slot to extract.
Expand Down
84 changes: 84 additions & 0 deletions tests/slots/test_slot_partial_extraction.py
RLKRo marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from chatsky.slots import RegexpSlot, GroupSlot
from chatsky.slots.slots import SlotManager
from chatsky.core import Message

import pytest

# Nested group that must be extracted as a whole.
_strict_nested_group = GroupSlot(
    three=RegexpSlot(regexp=r"3"),
    four=RegexpSlot(regexp=r"4"),
    allow_partial_extraction=False,
)

# Nested group that may be extracted partially.
_partial_nested_group = GroupSlot(
    five=RegexpSlot(regexp=r"5"),
    six=RegexpSlot(regexp=r"6"),
    allow_partial_extraction=True,
)

# Slot tree under test: a partially extractable ``root_slot`` holding two
# value slots and the two nested groups above.
test_slot = GroupSlot(
    root_slot=GroupSlot(
        one=RegexpSlot(regexp=r"1"),
        two=RegexpSlot(regexp=r"2"),
        nested_group=_strict_nested_group,
        nested_partial_group=_partial_nested_group,
        allow_partial_extraction=True,
    )
)

# Expected value of every leaf slot once it has been extracted,
# keyed by the slot's full dotted name.
extracted_slots = dict(
    (
        ("root_slot.one", "1"),
        ("root_slot.two", "2"),
        ("root_slot.nested_group.three", "3"),
        ("root_slot.nested_group.four", "4"),
        ("root_slot.nested_partial_group.five", "5"),
        ("root_slot.nested_partial_group.six", "6"),
    )
)


@pytest.fixture(scope="function")
def context_with_request(context):
    """Provide a factory that adds a request to ``context`` and returns it."""

    def add_request_to_context(request):
        # Wrap the raw request in a ``Message`` before adding it.
        context.add_request(Message(request))
        return context

    return add_request_to_context


@pytest.fixture(scope="function")
def empty_slot_manager():
    """Provide a fresh ``SlotManager`` whose root slot is ``test_slot``."""
    slot_manager = SlotManager()
    slot_manager.set_root_slot(test_slot)
    return slot_manager


def get_extracted_slots(manager: SlotManager):
    """
    Collect the expected values of all slots that ``manager`` has extracted.

    Slots are visited in the order of ``extracted_slots``; slots that are not
    marked as extracted are skipped.

    :param manager: Slot manager to inspect.
    :return: List of expected values for every extracted slot.
    :raises RuntimeError: If an extracted slot holds a value other than the
        expected one.
    """
    matched = []
    for slot_name, expected in extracted_slots.items():
        stored = manager.get_extracted_slot(slot_name)
        if not stored.__slot_extracted__:
            continue
        if stored.value == expected:
            matched.append(expected)
            continue
        raise RuntimeError(f"Extracted value {stored} does not match expected {expected}.")
    return matched


@pytest.mark.parametrize(
    "message,extracted",
    [
        ("1 2 3", ["1", "2"]),
        ("1 3 5", ["1", "5"]),
        ("3 4 5 6", ["3", "4", "5", "6"]),
    ],
)
async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager):
    """Extract ``root_slot`` from ``message`` and compare the stored leaf values with ``extracted``."""
    ctx = context_with_request(message)
    await empty_slot_manager.extract_slot("root_slot", ctx, success_only=False)

    assert get_extracted_slots(empty_slot_manager) == extracted


async def test_slot_storage_update(context_with_request, empty_slot_manager):
    """Run several extractions in a row and check the stored values after each one."""
    # (slot to extract, user message, leaf values expected in storage afterwards)
    steps = (
        ("root_slot", "1 3 5", ["1", "5"]),
        ("root_slot", "2 4 6", ["1", "2", "5", "6"]),
        ("root_slot.nested_group", "3 4", ["1", "2", "3", "4", "5", "6"]),
    )

    for slot_name, message, expected in steps:
        await empty_slot_manager.extract_slot(slot_name, context_with_request(message), success_only=False)

        assert get_extracted_slots(empty_slot_manager) == expected
202 changes: 202 additions & 0 deletions tutorials/slots/2_partial_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# %% [markdown]
"""
# 2. Partial slot extraction

This tutorial shows advanced slot extraction options that make it
possible to extract only some of the slots.
"""

# %pip install chatsky

# %%
from chatsky import (
RESPONSE,
TRANSITIONS,
PRE_TRANSITION,
PRE_RESPONSE,
GLOBAL,
LOCAL,
Pipeline,
Transition as Tr,
conditions as cnd,
processing as proc,
responses as rsp,
)

from chatsky.slots import RegexpSlot, GroupSlot

from chatsky.utils.testing import (
check_happy_path,
is_interactive_mode,
)

# %% [markdown]
"""
## Default behavior

By default, slot extraction will write a value into slot storage regardless
of whether the extraction was successful.
If extraction fails, the slot will be marked as not-extracted
and its value will be the `default_value` (`None` by default).

If a group slot is being extracted, the extraction is considered successful
only if all child slots are successfully extracted.

## Success only extraction

The `Extract` function accepts a `success_only` flag which makes it so
that the extracted value is not saved unless extraction is successful.

This means that unsuccessfully trying to extract a slot after
it has already been extracted will not overwrite the previously extracted
value.

Note that `success_only` is `True` by default.

## Partial group slot extraction

A group slot marked with `allow_partial_extraction` only saves values
of successfully extracted child slots.
Extracting such a group slot is equivalent to extracting every child slot
with the `success_only` flag.

A partially extracted group slot is always considered successfully extracted
for the purposes of the `success_only` flag.

## Code explanation

In this example we define two group slots: `person` and `friend`.
Note that in the `friend` slot we set `allow_partial_extraction` to `True`
which allows us to _update_ slot values and not
rewrite them in case we don't get full information at once.

So if we send "John Doe" as a full name and after that send only a first name
(e.g. "Mike") the extracted friend's name would be "Mike Doe"
and not "Mike default_surname".
"""

# %%
# Slot tree of the tutorial, keyed by top-level slot name.
SLOTS = dict(
    # Extracted as a whole: no partial extraction is enabled for this group.
    person=GroupSlot(
        # First run of ASCII letters in the message.
        username=RegexpSlot(
            regexp=r"([a-zA-Z]+)",
            match_group_idx=1,
        ),
        # First run of ASCII letters, "@" and "." in the message.
        email=RegexpSlot(
            regexp=r"([a-z@\.A-Z]+)",
            match_group_idx=1,
        ),
    ),
    # Partial extraction is enabled for this group.
    friend=GroupSlot(
        # Capitalized word at the start of the message, followed by a space.
        first_name=RegexpSlot(
            regexp=r"^[A-Z][a-z]+?(?= )",
            default_value="default_name",
        ),
        # Capitalized word preceded by a space.
        last_name=RegexpSlot(
            regexp=r"(?<= )[A-Z][a-z]+",
            default_value="default_surname",
        ),
        allow_partial_extraction=True,
    ),
)

# Dialog script: `user_flow` collects the `person` slots, `friend_flow`
# collects the `friend` slots, and `root` holds the start, fallback and
# response nodes.
script = {
    GLOBAL: {
        TRANSITIONS: [
            Tr(dst=("user_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart"))
        ]
    },
    "user_flow": {
        LOCAL: {
            # `success_only=True`: the extracted value is saved
            # only if the extraction is successful.
            PRE_TRANSITION: {
                "get_slots": proc.Extract("person", success_only=True)
            },
            TRANSITIONS: [
                Tr(
                    dst=("root", "utter_user"),
                    cnd=cnd.SlotsExtracted("person", mode="any"),
                    priority=1.2,
                ),
                Tr(dst=("user_flow", "repeat_question"), priority=0.8),
            ],
        },
        "ask": {
            RESPONSE: "Please, send your username and email in one message."
        },
        "repeat_question": {
            RESPONSE: "Please, send your username and email again."
        },
    },
    "friend_flow": {
        LOCAL: {
            # `friend` allows partial extraction, so only successfully
            # extracted child slots are saved.
            PRE_TRANSITION: {
                "get_slots": proc.Extract("friend", success_only=False)
            },
            TRANSITIONS: [
                Tr(
                    # The destination must name the `utter_friend` node
                    # defined in the `root` flow below.
                    dst=("root", "utter_friend"),
                    cnd=cnd.SlotsExtracted(
                        "friend.first_name", "friend.last_name", mode="any"
                    ),
                    priority=1.2,
                ),
                Tr(dst=("friend_flow", "repeat_question"), priority=0.8),
            ],
        },
        "ask": {RESPONSE: "Please, send your friends name"},
        "repeat_question": {RESPONSE: "Please, send your friends name again."},
    },
    "root": {
        "start": {
            TRANSITIONS: [Tr(dst=("user_flow", "ask"))],
        },
        "fallback": {
            RESPONSE: "Finishing query",
            TRANSITIONS: [Tr(dst=("user_flow", "ask"))],
        },
        "utter_friend": {
            RESPONSE: rsp.FilledTemplate(
                "Your friend is {friend.first_name} {friend.last_name}"
            ),
            TRANSITIONS: [Tr(dst=("friend_flow", "ask"))],
        },
        "utter_user": {
            RESPONSE: "Your username is {person.username}. "
            "Your email is {person.email}.",
            PRE_RESPONSE: {"fill": proc.FillTemplate()},
            TRANSITIONS: [Tr(dst=("root", "utter_friend"))],
        },
    },
}

# Expected dialog as (user request, bot response) pairs.
# NOTE(review): several expected responses (e.g.
# "Write your username (my username is ...):") are not defined by any node
# in `script`; this happy path looks stale -- confirm by running
# `check_happy_path` and update it to match the script.
# NOTE(review): "[email protected]" looks like an obfuscated e-mail
# placeholder rather than a real address -- confirm and restore the value.
HAPPY_PATH = [
    ("hi", "Write your username (my username is ...):"),
    ("my username is groot", "Write your email (my email is ...):"),
    (
        "my email is [email protected]",
        "Please, name me one of your friends: (John Doe)",
    ),
    ("Bob Page", "Your friend is Bob Page"),
    ("ok", "Your username is groot. Your email is [email protected]."),
    ("ok", "Finishing query"),
    (
        "again",
        "Please, name me one of your friends: (John Doe)",
    ),
    ("Jim ", "Your friend is Jim Page"),
]


# %%
# Assemble the pipeline from the script and the slot tree defined above.
pipeline = Pipeline(
    script=script,
    slots=SLOTS,
    start_label=("root", "start"),
    fallback_label=("root", "fallback"),
)


# %%
if __name__ == "__main__":
    # Check the pipeline against the expected dialog first.
    check_happy_path(pipeline, HAPPY_PATH, printout=True)

    # Start the pipeline only in interactive mode.
    if is_interactive_mode():
        pipeline.run()
Loading