Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-1652] Add gRPC support #1109

Merged
merged 19 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Format code
  • Loading branch information
radovanZRasa committed Jun 18, 2024
commit 1e593efa2cdecc78f8a31c74bfae3f7292eefbee
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ formatter:

lint:
poetry run ruff check rasa_sdk tests --ignore D
poetry run black --check rasa_sdk tests
poetry run black --exclude="rasa_sdk/grpc_py" --check rasa_sdk tests
radovanZRasa marked this conversation as resolved.
Show resolved Hide resolved
make lint-docstrings

# Compare against `main` if no branch was provided
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = [ "py37", "py38", "py39", "py310",]
exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist))"
exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist ))"

[tool.poetry]
name = "rasa-sdk"
Expand Down Expand Up @@ -67,11 +67,13 @@ ignore_missing_imports = true
show_error_codes = true
warn_redundant_casts = true
warn_unused_ignores = true
exclude = "rasa_sdk/grpc_py"

[tool.ruff]
ignore = [ "D100", "D104", "D105", "RUF005",]
line-length = 88
select = [ "D", "E", "F", "W", "RUF",]
exclude = [ "rasa_sdk/grpc_py" ]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
Expand Down
4 changes: 1 addition & 3 deletions rasa_sdk/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,5 @@ def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None:
help="Configuration file for the assistant as a yml file.",
)
parser.add_argument(
"--grpc",
help="Starts grpc server instead of http",
action="store_true"
"--grpc", help="Starts grpc server instead of http", action="store_true"
)
5 changes: 4 additions & 1 deletion rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ async def actions(_) -> HTTPResponse:
if auto_reload:
executor.reload()

body = [action_name_item.model_dump() for action_name_item in executor.list_actions()] # noqa: E501
body = [
action_name_item.model_dump()
for action_name_item in executor.list_actions()
]
return response.json(body, status=200)

@app.exception(Exception)
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def utter_image_url(self, image: Text, **kwargs: Any) -> None:

class ActionExecutor:
"""Executes actions."""

def __init__(self) -> None:
"""Initializes the `ActionExecutor`."""
self.actions: Dict[Text, Callable] = {}
Expand Down Expand Up @@ -516,4 +517,5 @@ def list_actions(self) -> List[ActionName]:

class ActionName(BaseModel):
"""Model for action name."""

name: str = Field(alias="name")
8 changes: 5 additions & 3 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@
)
from rasa_sdk.utils import (
check_version_compatibility,
number_of_sanic_workers, file_as_bytes,
number_of_sanic_workers,
file_as_bytes,
)

logger = logging.getLogger(__name__)


class GRPCActionServerHealthCheck(health_pb2_grpc.HealthServiceServicer):
"""Runs health check RPC which is served through gRPC server."""

def __init__(self) -> None:
"""Initializes the HealthServicer."""
pass
Expand Down Expand Up @@ -232,8 +234,8 @@ async def run_grpc(
f"[::]:{port}",
server_credentials=grpc.ssl_server_credentials(
private_key_certificate_chain_pairs=[(private_key, certificate_chain)],
root_certificates = ca_cert,
require_client_auth = True if ca_cert else False,
root_certificates=ca_cert,
require_client_auth=True if ca_cert else False,
),
)
else:
Expand Down
21 changes: 8 additions & 13 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Tracker:
@classmethod
def from_dict(cls, state: "TrackerState") -> "Tracker":
"""Create a tracker from dump."""

return Tracker(
state["sender_id"],
state.get("slots", {}),
Expand All @@ -49,7 +48,6 @@ def __init__(
stack: Optional[List[Dict[Text, Any]]] = None,
radovanZRasa marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Initialize the tracker."""

# list of previously seen events
self.events = events
# id of the source of the messages
Expand All @@ -72,6 +70,7 @@ def __init__(

@property
def active_form(self) -> Dict[Text, Any]:
"""Get the currently active form."""
warnings.warn(
"Use of `active_form` is deprecated. Please use `active_loop insteaad.",
DeprecationWarning,
Expand All @@ -80,7 +79,6 @@ def active_form(self) -> Dict[Text, Any]:

def current_state(self) -> Dict[Text, Any]:
"""Return the current tracker state as an object."""

if len(self.events) > 0:
latest_event_time = self.events[-1].get("timestamp")
else:
Expand All @@ -100,12 +98,11 @@ def current_state(self) -> Dict[Text, Any]:
}

def current_slot_values(self) -> Dict[Text, Any]:
"""Return the currently set values of the slots"""
"""Return the currently set values of the slots."""
return self.slots

def get_slot(self, key) -> Optional[Any]:
"""Retrieves the value of a slot."""

if key in self.slots:
return self.slots[key]
else:
Expand Down Expand Up @@ -133,7 +130,6 @@ def get_latest_entity_values(
Returns:
List of entity values.
"""

entities = self.latest_message.get("entities", [])
return (
x.get("value")
Expand All @@ -144,8 +140,7 @@ def get_latest_entity_values(
)

def get_latest_input_channel(self) -> Optional[Text]:
"""Get the name of the input_channel of the latest UserUttered event"""

"""Get the name of the input_channel of the latest UserUttered event."""
for e in reversed(self.events):
if e.get("event") == "user":
return e.get("input_channel")
Expand Down Expand Up @@ -229,7 +224,8 @@ def applied_events(self) -> List[Dict[Text, Any]]:

def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
"""Removes events from `done_events` until the first
occurrence `event_type` is found which is also removed."""
occurrence `event_type` is found which is also removed.
"""
# list gets modified - hence we need to copy events!
for e in reversed(done_events[:]):
del done_events[-1]
Expand Down Expand Up @@ -262,7 +258,6 @@ def slots_to_validate(self) -> Dict[Text, Any]:
Returns:
A mapping of extracted slot candidates and their values.
"""

slots: Dict[Text, Any] = {}
count: int = 0

Expand Down Expand Up @@ -331,7 +326,6 @@ class Action:

def name(self) -> Text:
"""Unique identifier of this simple action."""

raise NotImplementedError("An action must implement a name")

async def run(
Expand All @@ -356,7 +350,6 @@ async def run(
A dictionary of `rasa_sdk.events.Event` instances that is
returned through the endpoint
"""

raise NotImplementedError("An action must implement its run method")

def __str__(self) -> Text:
Expand All @@ -365,7 +358,9 @@ def __str__(self) -> Text:

class ActionExecutionRejection(Exception):
"""Raising this exception will allow other policies
to predict another action"""
to predict another action
.
"""

def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

class Element(dict):
"""Represents an element in a list of elements in a rich message."""

__acceptable_keys = ["title", "item_url", "image_url", "subtitle", "buttons"]

def __init__(self, *args, **kwargs):
Expand All @@ -46,6 +47,7 @@ def __init__(self, *args, **kwargs):

class Button(dict):
"""Represents a button in a rich message."""

pass


Expand Down
Loading