From fb8cd9537083125bfb4169c3b1a1ff154d6aa521 Mon Sep 17 00:00:00 2001 From: Radovan Zivkovic Date: Tue, 18 Jun 2024 12:12:08 +0200 Subject: [PATCH] Format code --- Makefile | 2 +- pyproject.toml | 3 ++- rasa_sdk/cli/arguments.py | 4 +--- rasa_sdk/endpoint.py | 5 ++++- rasa_sdk/executor.py | 2 ++ rasa_sdk/grpc_server.py | 8 +++++--- rasa_sdk/interfaces.py | 21 ++++++++------------- rasa_sdk/utils.py | 2 ++ 8 files changed, 25 insertions(+), 22 deletions(-) diff --git a/Makefile b/Makefile index 42bfecddf..48dd7f577 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ formatter: lint: poetry run ruff check rasa_sdk tests --ignore D - poetry run black --check rasa_sdk tests + poetry run black --exclude="rasa_sdk/grpc_py" --check rasa_sdk tests make lint-docstrings # Compare against `main` if no branch was provided diff --git a/pyproject.toml b/pyproject.toml index 5a00ed652..0c0d0e8eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.black] line-length = 88 target-version = [ "py37", "py38", "py39", "py310",] -exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist))" +exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist ))" [tool.poetry] name = "rasa-sdk" @@ -72,6 +72,7 @@ warn_unused_ignores = true ignore = [ "D100", "D104", "D105", "RUF005",] line-length = 88 select = [ "D", "E", "F", "W", "RUF",] +exclude = [ "rasa_sdk/grpc_py" ] [tool.poetry.dependencies] python = ">=3.8,<3.11" diff --git a/rasa_sdk/cli/arguments.py b/rasa_sdk/cli/arguments.py index 5a81f5e03..90e282e2c 100644 --- a/rasa_sdk/cli/arguments.py +++ b/rasa_sdk/cli/arguments.py @@ -82,7 +82,5 @@ def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None: help="Configuration file for the assistant as a yml file.", ) parser.add_argument( - "--grpc", - help="Starts grpc server instead of http", - action="store_true" + "--grpc", help="Starts grpc server instead of http", action="store_true" ) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 1745787bb..0a0326042 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -172,7 +172,10 @@ async def actions(_) -> HTTPResponse: if auto_reload: executor.reload() - body = [action_name_item.model_dump() for action_name_item in executor.list_actions()] # noqa: E501 + body = [ + action_name_item.model_dump() + for action_name_item in executor.list_actions() + ] return response.json(body, status=200) @app.exception(Exception) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index 2c1c4c2b4..24d4705ec 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -172,6 +172,7 @@ def utter_image_url(self, image: Text, **kwargs: Any) -> None: class ActionExecutor: """Executes actions.""" + def __init__(self) -> None: """Initializes the `ActionExecutor`.""" self.actions: Dict[Text, Callable] = {} @@ -516,4 +517,5 @@ def list_actions(self) -> List[ActionName]: class ActionName(BaseModel): """Model for action name.""" + name: str = Field(alias="name") diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 07f4242e0..20dc62948 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -45,7 +45,8 @@ ) from rasa_sdk.utils import ( check_version_compatibility, - number_of_sanic_workers, file_as_bytes, + number_of_sanic_workers, + file_as_bytes, ) logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ class GRPCActionServerHealthCheck(health_pb2_grpc.HealthServiceServicer): """Runs health check RPC which is served through gRPC server.""" + def __init__(self) -> None: """Initializes the HealthServicer.""" pass @@ -232,8 +234,8 @@ async def run_grpc( f"[::]:{port}", server_credentials=grpc.ssl_server_credentials( private_key_certificate_chain_pairs=[(private_key, certificate_chain)], - root_certificates = ca_cert, - require_client_auth = True if ca_cert else False, + root_certificates=ca_cert, + require_client_auth=True if ca_cert else False, ), ) else: diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index e21462e46..21791f416 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -23,7 +23,6 @@ class Tracker: @classmethod def from_dict(cls, state: "TrackerState") -> "Tracker": """Create a tracker from dump.""" - return Tracker( state["sender_id"], state.get("slots", {}), @@ -49,7 +48,6 @@ def __init__( stack: Optional[List[Dict[Text, Any]]] = None, ) -> None: """Initialize the tracker.""" - # list of previously seen events self.events = events # id of the source of the messages @@ -72,6 +70,7 @@ def __init__( @property def active_form(self) -> Dict[Text, Any]: + """Get the currently active form.""" warnings.warn( "Use of `active_form` is deprecated. Please use `active_loop insteaad.", DeprecationWarning, @@ -80,7 +79,6 @@ def active_form(self) -> Dict[Text, Any]: def current_state(self) -> Dict[Text, Any]: """Return the current tracker state as an object.""" - if len(self.events) > 0: latest_event_time = self.events[-1].get("timestamp") else: @@ -100,12 +98,11 @@ def current_state(self) -> Dict[Text, Any]: } def current_slot_values(self) -> Dict[Text, Any]: - """Return the currently set values of the slots""" + """Return the currently set values of the slots.""" return self.slots def get_slot(self, key) -> Optional[Any]: """Retrieves the value of a slot.""" - if key in self.slots: return self.slots[key] else: @@ -133,7 +130,6 @@ def get_latest_entity_values( Returns: List of entity values. """ - entities = self.latest_message.get("entities", []) return ( x.get("value") @@ -144,8 +140,7 @@ def get_latest_entity_values( ) def get_latest_input_channel(self) -> Optional[Text]: - """Get the name of the input_channel of the latest UserUttered event""" - + """Get the name of the input_channel of the latest UserUttered event.""" for e in reversed(self.events): if e.get("event") == "user": return e.get("input_channel") @@ -229,7 +224,8 @@ def applied_events(self) -> List[Dict[Text, Any]]: def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]): """Removes events from `done_events` until the first - occurrence `event_type` is found which is also removed.""" + occurrence `event_type` is found which is also removed. + """ # list gets modified - hence we need to copy events! for e in reversed(done_events[:]): del done_events[-1] @@ -262,7 +258,6 @@ def slots_to_validate(self) -> Dict[Text, Any]: Returns: A mapping of extracted slot candidates and their values. """ - slots: Dict[Text, Any] = {} count: int = 0 @@ -331,7 +326,6 @@ class Action: def name(self) -> Text: """Unique identifier of this simple action.""" - raise NotImplementedError("An action must implement a name") async def run( @@ -356,7 +350,6 @@ async def run( A dictionary of `rasa_sdk.events.Event` instances that is returned through the endpoint """ - raise NotImplementedError("An action must implement its run method") def __str__(self) -> Text: @@ -365,7 +358,9 @@ def __str__(self) -> Text: class ActionExecutionRejection(Exception): """Raising this exception will allow other policies - to predict another action""" + to predict another action + . + """ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: self.action_name = action_name diff --git a/rasa_sdk/utils.py b/rasa_sdk/utils.py index c058e033b..b1404b48e 100644 --- a/rasa_sdk/utils.py +++ b/rasa_sdk/utils.py @@ -33,6 +33,7 @@ class Element(dict): """Represents an element in a list of elements in a rich message.""" + __acceptable_keys = ["title", "item_url", "image_url", "subtitle", "buttons"] def __init__(self, *args, **kwargs): @@ -46,6 +47,7 @@ def __init__(self, *args, **kwargs): class Button(dict): """Represents a button in a rich message.""" + pass