diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index cd0f32186..744e45e6b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -6,4 +6,4 @@ - [ ] added some tests for the functionality - [ ] updated the documentation in the [rasaHQ/rasa](https://github.com/rasaHQ/rasa) - [ ] updated the changelog (please check [changelog](https://github.com/RasaHQ/rasa-sdk/tree/main/changelog) for instructions) -- [ ] reformat files using `black` (please check [Readme](https://github.com/RasaHQ/rasa-sdk#code-style) for instructions) +- [ ] reformat files using `ruff` (please check [Readme](https://github.com/RasaHQ/rasa-sdk#code-style) for instructions) diff --git a/Makefile b/Makefile index a48203ca6..cd36690b2 100644 --- a/Makefile +++ b/Makefile @@ -27,11 +27,11 @@ types: ## check types poetry run mypy rasa_sdk formatter: ## format code - poetry run black rasa_sdk tests + poetry run ruff format rasa_sdk tests lint: ## check style with ruff and black poetry run ruff check rasa_sdk tests --ignore D - poetry run black --exclude="rasa_sdk/grpc_py" --check rasa_sdk tests + poetry run ruff format --check rasa_sdk tests make lint-docstrings make check-generate-grpc-code-in-sync diff --git a/README.md b/README.md index a400129c2..004b45f3e 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ make install ## Code Style -To ensure a standardized code style we use the formatter [black](https://github.com/ambv/black). +To ensure a standardized code style we use the formatter [ruff](https://github.com/astral-sh/ruff). If your code is not formatted properly, GitHub CI will fail to build. If you want to automatically format your code on every commit, you can use [pre-commit](https://pre-commit.com/). diff --git a/changelog/1130.misc.md b/changelog/1130.misc.md new file mode 100644 index 000000000..c759bc6d2 --- /dev/null +++ b/changelog/1130.misc.md @@ -0,0 +1,2 @@ +Update ruff to 0.3.7. +Switch to ruff as code formatter. diff --git a/poetry.lock b/poetry.lock index 51b425790..7cace77ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1505,28 +1505,28 @@ files = [ [[package]] name = "ruff" -version = "0.0.256" -description = "An extremely fast Python linter, written in Rust." +version = "0.3.7" +description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.256-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:eb8e949f6e7fb16f9aa163fcc13318e2b7910577513468417e5b003b984410a1"}, - {file = "ruff-0.0.256-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:48a42f0ec4c5a3c3b062e947b2a5f8f7a4264761653fb0ee656a9b535ae6d8d7"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36ca633cfc335869643a13e2006f13a63bc4cb94073aa9508ceb08a1e3afe3af"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:80fa5d3a40dd0b65c6d6adea4f825984d5d3a215a25d90cc6139978cb22ea1cd"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0f88839b886db3577136375865bd080b9ed6f9b85bb990d897780e5a30ca3c2"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:fe6d77a43b2d52f45ee42f6f682198ed1c34cd0165812e276648981dfd50ad36"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3878593507b281b2615702ece06426e8b27076e8fedf658bf0c5e1e5e2ad1b40"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e052ec4d5c92663caa662b68fe1902ec10eddac2783229b1c5f20f3df62a865"}, - {file = "ruff-0.0.256-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2116bd67e52ade9f90e5a3a3aa511a9b179c699690221bdd5bb267dbf7e94b22"}, - {file = "ruff-0.0.256-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3c6e93d7818a75669328e49a0f7070c40e18676ca8e56ca9c566633bef4d8d05"}, - {file = "ruff-0.0.256-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7ebb7de4e62d751b65bb15418a83ac5d555afb3eaa3ad549dea21744da34ae86"}, - {file = "ruff-0.0.256-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f310bfc76c0404a487759c8904f57bf51653c46e686c800efc1ff1d165a59a04"}, - {file = "ruff-0.0.256-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:93a0cfec812b2ba57bff22b176901e0ddf44e4d42a9bd7da7ffb5e53df13fd6e"}, - {file = "ruff-0.0.256-py3-none-win32.whl", hash = "sha256:d63e5320bc2d91e94925cd1863e381a48edf087041035967faf2614bb36a6a0d"}, - {file = "ruff-0.0.256-py3-none-win_amd64.whl", hash = "sha256:859c8ffb1801895fe043a2b85a45cd0ff35667ddea4b465ba2a29d275550814a"}, - {file = "ruff-0.0.256-py3-none-win_arm64.whl", hash = "sha256:64b276149e86c3d234608d3fe1da77535865e03debd3a1d5d04576f7f5031bbb"}, - {file = "ruff-0.0.256.tar.gz", hash = "sha256:f9a96b34a4870ee8cf2f3779cd7854620d1788a83b52374771266cf800541bb7"}, + {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0e8377cccb2f07abd25e84fc5b2cbe48eeb0fea9f1719cad7caedb061d70e5ce"}, + {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:15a4d1cc1e64e556fa0d67bfd388fed416b7f3b26d5d1c3e7d192c897e39ba4b"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d28bdf3d7dc71dd46929fafeec98ba89b7c3550c3f0978e36389b5631b793663"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:379b67d4f49774ba679593b232dcd90d9e10f04d96e3c8ce4a28037ae473f7bb"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c060aea8ad5ef21cdfbbe05475ab5104ce7827b639a78dd55383a6e9895b7c51"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ebf8f615dde968272d70502c083ebf963b6781aacd3079081e03b32adfe4d58a"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48098bd8f5c38897b03604f5428901b65e3c97d40b3952e38637b5404b739a2"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da8a4fda219bf9024692b1bc68c9cff4b80507879ada8769dc7e985755d662ea"}, + {file = "ruff-0.3.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c44e0149f1d8b48c4d5c33d88c677a4aa22fd09b1683d6a7ff55b816b5d074f"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3050ec0af72b709a62ecc2aca941b9cd479a7bf2b36cc4562f0033d688e44fa1"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a29cc38e4c1ab00da18a3f6777f8b50099d73326981bb7d182e54a9a21bb4ff7"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5b15cc59c19edca917f51b1956637db47e200b0fc5e6e1878233d3a938384b0b"}, + {file = "ruff-0.3.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e491045781b1e38b72c91247cf4634f040f8d0cb3e6d3d64d38dcf43616650b4"}, + {file = "ruff-0.3.7-py3-none-win32.whl", hash = "sha256:bc931de87593d64fad3a22e201e55ad76271f1d5bfc44e1a1887edd0903c7d9f"}, + {file = "ruff-0.3.7-py3-none-win_amd64.whl", hash = "sha256:5ef0e501e1e39f35e03c2acb1d1238c595b8bb36cf7a170e7c1df1b73da00e74"}, + {file = "ruff-0.3.7-py3-none-win_arm64.whl", hash = "sha256:789e144f6dc7019d1f92a812891c645274ed08af6037d11fc65fcbc183b7d59f"}, + {file = "ruff-0.3.7.tar.gz", hash = "sha256:d5c1aebee5162c2226784800ae031f660c350e7a3402c4d1f8ea4e97e232e3ba"}, ] [[package]] @@ -2058,4 +2058,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "73c3b65eb0052bae1164bd81d1bfb8e4e2703cd708037ea9d84e7d2929cbaf7f" +content-hash = "2df8a1c5cf589611c3603c5600b0aebfa3d3a44b40113f9a31b82c8a8afca066" diff --git a/pyproject.toml b/pyproject.toml index 322b8a0da..fa8bf036a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,11 +2,6 @@ requires = [ "poetry-core>=1.0.4",] build-backend = "poetry.core.masonry.api" -[tool.black] -line-length = 88 -target-version = [ "py37", "py38", "py39", "py310",] -exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist ))" - [tool.poetry] name = "rasa-sdk" version = "3.9.0" @@ -70,10 +65,13 @@ warn_unused_ignores = true exclude = "rasa_sdk/grpc_py" [tool.ruff] -ignore = [ "D100", "D104", "D105", "RUF005",] line-length = 88 +target-version = "py38" +exclude = [ "rasa_sdk/grpc_py", "eggs", ".git", ".pytest_cache", "build", "dist", ".DS_Store"] + +[tool.ruff.lint] +ignore = [ "D100", "D101", "D102", "D103", "D104", "D105", "RUF005",] select = [ "D", "E", "F", "W", "RUF",] -exclude = [ "rasa_sdk/grpc_py",] [tool.poetry.dependencies] python = ">=3.8,<3.11" @@ -108,7 +106,7 @@ semantic_version = "^2.8.5" mypy = "^1.5" sanic-testing = "^22.12" -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" [tool.pytest.ini_options] @@ -116,6 +114,6 @@ python_functions = "test_" asyncio_mode = "auto" [tool.poetry.group.dev.dependencies] -ruff = ">=0.0.256,<0.0.286" +ruff = ">=0.3.5,<0.4.0" pytest-asyncio = "^0.21.0" types-protobuf = "4.25.0.20240417" diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index 21d20116a..3ac7c30d7 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -115,8 +115,9 @@ def get_latest_entity_values( entity_role: Optional[Text] = None, entity_group: Optional[Text] = None, ) -> Iterator[Text]: - """Get entity values found for the passed entity type and optional role and - group in latest message. + """Get entity values found for the passed entity type. + + Optionally role and group of the entities in the last message can be specified. If you are only interested in the first entity of a given type use `next(tracker.get_latest_entity_values("my_entity_name"), None)`. @@ -222,9 +223,17 @@ def filter_function(e: Dict[Text, Any]) -> bool: def applied_events(self) -> List[Dict[Text, Any]]: """Returns all actions that should be applied - w/o reverted events.""" - def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]): - """Removes events from `done_events` until the first - occurrence `event_type` is found which is also removed. + def undo_till_previous( + event_type: Text, done_events: List[Dict[Text, Any]] + ) -> None: + """Removes events from `done_events` until `event_type` is found. + + Removes all events until first occurrence of an `event_type` is found + including the `event_type`. + + Args: + event_type: The type of event to remove. + done_events: The list of events to remove the event from. """ # list gets modified - hence we need to copy events! for e in reversed(done_events[:]): @@ -357,27 +366,33 @@ def __str__(self) -> Text: class ActionExecutionRejection(Exception): - """Raising this exception will allow other policies - to predict another action - . - """ + """Raising this exception will allow other policies to predict another action.""" def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: + """Create a rejection exception. + + Args: + action_name: Name of the action that should be rejected. + message: Optional message to provide more information + """ self.action_name = action_name self.message = message or f"Custom action '{action_name}' rejected execution." def __str__(self) -> Text: + """Return the string representation of the exception.""" return self.message class ActionNotFoundException(Exception): def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: + """Create an exception for when an action is not found.""" self.action_name = action_name self.message = ( message or f"No registered action found for name '{action_name}'." ) def __str__(self) -> Text: + """Return the string representation of the exception.""" return self.message @@ -385,6 +400,7 @@ class ActionMissingDomainException(Exception): """Raising this exception when the domain is missing.""" def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: + """Create an exception for when the domain is missing.""" self.action_name = action_name self.message = ( message @@ -394,4 +410,5 @@ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: ) def __str__(self) -> Text: + """Return the string representation of the exception.""" return self.message diff --git a/rasa_sdk/knowledge_base/actions.py b/rasa_sdk/knowledge_base/actions.py index 097bd5031..5b57a4026 100644 --- a/rasa_sdk/knowledge_base/actions.py +++ b/rasa_sdk/knowledge_base/actions.py @@ -25,8 +25,8 @@ class ActionQueryKnowledgeBase(Action): - """ - Action that queries the knowledge base for objects and attributes of an object. + """Action that queries the knowledge base for objects and attributes of an object. + The action needs to be inherited and the knowledge base needs to be set. In order to actually query the knowledge base you need to: - create your knowledge base @@ -39,10 +39,12 @@ class ActionQueryKnowledgeBase(Action): def __init__( self, knowledge_base: KnowledgeBase, use_last_object_mention: bool = True ) -> None: + """Creates an action that queries the knowledge base.""" self.knowledge_base = knowledge_base self.use_last_object_mention = use_last_object_mention def name(self) -> Text: + """Returns the unique identifier of this action.""" return "action_query_knowledge_base" def utter_attribute_value( @@ -52,12 +54,10 @@ def utter_attribute_value( attribute_name: Text, attribute_value: Text, ): - """ - Utters a response that informs the user about the attribute value of the - attribute of interest. + """Utters a response that informs the user about the value of an attribute. Args: - dispatcher: the dispatcher + dispatcher: the collecting dispatcher object_name: the name of the object attribute_name: the name of the attribute attribute_value: the value of the attribute @@ -83,8 +83,7 @@ async def utter_objects( object_type: Text, objects: List[Dict[Text, Any]], ): - """ - Utters a response to the user that lists all found objects. + """Utters a response to the user that lists all found objects. Args: dispatcher: the dispatcher @@ -113,11 +112,13 @@ async def run( tracker: Tracker, domain: "DomainDict", ) -> List[Dict[Text, Any]]: - """ - Executes this action. If the user ask a question about an attribute, - the knowledge base is queried for that attribute. Otherwise, if no - attribute was detected in the latest request it assumes user is talking - about a new object type and, multiple objects of the requested type are + """Executes the action. + + If the user ask a question about an attribute, + the knowledge base is queried for that attribute. + Otherwise, if no attribute was detected in the latest + request it assumes user is talking about a new object type and, + multiple objects of the requested type are returned from the knowledge base. Args: @@ -126,7 +127,6 @@ async def run( domain: the domain Returns: list of slots - """ object_type = tracker.get_slot(SLOT_OBJECT_TYPE) last_object_type = tracker.get_slot(SLOT_LAST_OBJECT_TYPE) @@ -168,13 +168,15 @@ async def run( async def _query_objects( self, dispatcher: CollectingDispatcher, object_type: Text, tracker: Tracker ) -> List[Dict]: - """ + """Queries the knowledge base for objects of the requested object type. + Queries the knowledge base for objects of the requested object type and outputs those to the user. The objects are filtered by any attribute the user mentioned in the request. Args: dispatcher: the dispatcher + object_type: the object types tracker: the tracker Returns: list of slots @@ -227,12 +229,15 @@ async def _query_attribute( attribute: Text, tracker: Tracker, ) -> List[Dict]: - """ + """Query the knowledge base using value of the attribute of the object. + Queries the knowledge base for the value of the requested attribute of the mentioned object and outputs it to the user. Args: dispatcher: the dispatcher + object_type: the object type + attribute: the requested attribute tracker: the tracker Returns: list of slots diff --git a/rasa_sdk/knowledge_base/storage.py b/rasa_sdk/knowledge_base/storage.py index 22f31e853..4c7f3c012 100644 --- a/rasa_sdk/knowledge_base/storage.py +++ b/rasa_sdk/knowledge_base/storage.py @@ -12,6 +12,7 @@ class KnowledgeBase: def __init__(self) -> None: + """Create a new knowledge base.""" self.ordinal_mention_mapping = { "1": lambda lst: lst[0], "2": lambda lst: lst[1], @@ -33,8 +34,7 @@ def __init__(self) -> None: ] = defaultdict(lambda: lambda obj: obj["name"]) async def get_attributes_of_object(self, object_type: Text) -> List[Text]: - """ - Returns a list of all attributes that belong to the provided object type. + """Returns a list of all attributes that belong to the provided object type. Args: object_type: the object type @@ -44,8 +44,7 @@ async def get_attributes_of_object(self, object_type: Text) -> List[Text]: raise NotImplementedError("Method is not implemented.") async def get_key_attribute_of_object(self, object_type: Text) -> Text: - """ - Returns the key attribute for the given object type. + """Returns the key attribute for the given object type. Args: object_type: the object type @@ -57,21 +56,20 @@ async def get_key_attribute_of_object(self, object_type: Text) -> Text: async def get_representation_function_of_object( self, object_type: Text ) -> Callable: - """ - Returns a lamdba function that takes the object and returns a string - representation of it. + """Returns a lambda function that returns string representation of the object. Args: object_type: the object type - Returns: lamdba function + Returns: lambda function """ return self.representation_function[object_type] def set_ordinal_mention_mapping(self, mapping: Dict[Text, Callable]) -> None: - """ - Overwrites the default ordinal mention mapping. E.g. the mapping that - maps, for example, "first one" to the first element in a list. + """Overwrites the default ordinal mention mapping. + + E.g. the mapping that maps, for example, + "first one" to the first element in a list. Args: mapping: the ordinal mention mapping @@ -81,9 +79,10 @@ def set_ordinal_mention_mapping(self, mapping: Dict[Text, Callable]) -> None: async def get_objects( self, object_type: Text, attributes: List[Dict[Text, Text]], limit: int = 5 ) -> List[Dict[Text, Any]]: - """ - Query the knowledge base for objects of the given type. Restrict the objects - by the provided attributes, if any attributes are given. + """Query the knowledge base for objects of the given type. + + Restrict the objects by the provided attributes, + if any attributes are given. Args: object_type: the object type @@ -97,8 +96,7 @@ async def get_objects( async def get_object( self, object_type: Text, object_identifier: Text ) -> Optional[Dict[Text, Any]]: - """ - Returns the object of the given type that matches the given object identifier. + """Returns the object of the given type and identifier. Args: object_type: the object type @@ -110,16 +108,14 @@ async def get_object( raise NotImplementedError("Method is not implemented.") def get_object_types(self) -> List[Text]: - """ - Returns a list of object types from knowledge base data. - """ + """Returns a list of object types from knowledge base data.""" raise NotImplementedError("Method is not implemented.") class InMemoryKnowledgeBase(KnowledgeBase): def __init__(self, data_file: Text) -> None: - """ - Initialize the in-memory knowledge base. + """Initialize the in-memory knowledge base. + Loads the data from the given data file into memory. Args: @@ -131,9 +127,7 @@ def __init__(self, data_file: Text) -> None: super().__init__() def load(self) -> None: - """ - Load the data from the given file and initialize an in-memory knowledge base. - """ + """Load the data from the file in an in-memory knowledge base.""" try: with open(self.data_file, encoding="utf-8") as f: content = f.read() @@ -151,8 +145,7 @@ def load(self) -> None: def set_representation_function_of_object( self, object_type: Text, representation_function: Callable ) -> None: - """ - Set the representation function of the given object type. + """Set the representation function of the given object type. Args: object_type: the object type @@ -163,8 +156,7 @@ def set_representation_function_of_object( def set_key_attribute_of_object( self, object_type: Text, key_attribute: Text ) -> None: - """ - Set the key attribute of the given object type. + """Set the key attribute of the given object type. Args: object_type: the object type @@ -173,6 +165,13 @@ def set_key_attribute_of_object( self.key_attribute[object_type] = key_attribute async def get_attributes_of_object(self, object_type: Text) -> List[Text]: + """Returns attributes of the object. + + Args: + object_type: the object type + + Returns: list of attributes of an object + """ if object_type not in self.data or not self.data[object_type]: return [] @@ -183,6 +182,16 @@ async def get_attributes_of_object(self, object_type: Text) -> List[Text]: async def get_objects( self, object_type: Text, attributes: List[Dict[Text, Text]], limit: int = 5 ) -> List[Dict[Text, Any]]: + """Returns objects of the given type. + + If attributes are specified, + the objects are filtered by these attributes. + + Args: + object_type: the object type + attributes: list of attributes to filter the objects by + limit: maximum number of objects to return + """ if object_type not in self.data: return [] @@ -207,6 +216,15 @@ async def get_objects( async def get_object( self, object_type: Text, object_identifier: Text ) -> Optional[Dict[Text, Any]]: + """Returns the object of the given type and identifier. + + Args: + object_type: the object type + object_identifier: value of the key attribute or the string + representation of the object + + Returns: the object of interest if found, else None + """ if object_type not in self.data: return None @@ -248,5 +266,9 @@ async def get_object( return objects_of_interest[0] def get_object_types(self) -> List[Text]: - """See parent class docstring.""" + """Returns a list of object types from knowledge base data. + + Returns: + list of object types + """ return list(self.data.keys()) diff --git a/rasa_sdk/knowledge_base/utils.py b/rasa_sdk/knowledge_base/utils.py index 24011ff89..3d72f7d20 100644 --- a/rasa_sdk/knowledge_base/utils.py +++ b/rasa_sdk/knowledge_base/utils.py @@ -18,8 +18,9 @@ def get_object_name( ordinal_mention_mapping: Dict[Text, Callable], use_last_object_mention: bool = True, ) -> Optional[Text]: - """ - Get the name of the object the user referred to. Either the NER detected the + """Get the name of the object the user referred to. + + Either the NER detected the object and stored its name in the corresponding slot (e.g. "PastaBar" is detected as "restaurant") or the user referred to the object by any kind of mention, such as "first one" or "it". @@ -29,10 +30,10 @@ def get_object_name( ordinal_mention_mapping: mapping that maps an ordinal mention to an object in a list use_last_object_mention: if true the last mentioned object is returned if - no other mention could be detected + no other mention could be detected Returns: the name of the actual object (value of key attribute in the - knowledge base) + knowledge base) """ mention = tracker.get_slot(SLOT_MENTION) object_type = tracker.get_slot(SLOT_OBJECT_TYPE) @@ -57,8 +58,7 @@ def get_object_name( def resolve_mention( tracker: "Tracker", ordinal_mention_mapping: Dict[Text, Callable] ) -> Optional[Text]: - """ - Resolve the given mention to the name of the actual object. + """Resolve the given mention to the name of the actual object. Different kind of mentions exist. We distinguish between ordinal mentions and all others for now. @@ -75,9 +75,8 @@ def resolve_mention( ordinal_mention_mapping: mapping that maps an ordinal mention to an object in a list - Returns: name of the actually object + Returns: name of an object """ - mention = tracker.get_slot(SLOT_MENTION) listed_items = tracker.get_slot(SLOT_LISTED_OBJECTS) last_object = tracker.get_slot(SLOT_LAST_OBJECT) @@ -105,7 +104,8 @@ def resolve_mention( def get_attribute_slots( tracker: "Tracker", object_attributes: List[Text] ) -> List[Dict[Text, Text]]: - """ + """Returns a list of attributes which are set in the tracker slots. + If the user mentioned one or multiple attributes of the provided object_type in an utterance, we extract all attribute values from the tracker and put them in a list. The list is used later on to filter a list of objects. @@ -135,8 +135,7 @@ def get_attribute_slots( def reset_attribute_slots( tracker: "Tracker", object_attributes: List[Text] ) -> List[Dict]: - """ - Reset all attribute slots of the current object type. + """Reset all attribute slots of the current object type. If the user is saying something like "Show me all restaurants with Italian cuisine.", the NER should detect "restaurant" as "object_type" and "Italian" as @@ -171,7 +170,8 @@ def match_extracted_entities_to_object_type( tracker: "Tracker", object_types: List[Text], ) -> Optional[Text]: - """ + """Returns the first object type if the last user message contains it. + If the user ask a question about an attribute using an object name and without specifying the object type, then this function searches the corresponding object type. (e.g: when user asks'price range of B&B', this diff --git a/rasa_sdk/tracing/endpoints.py b/rasa_sdk/tracing/endpoints.py index a5520a70d..e0a5b5dc8 100644 --- a/rasa_sdk/tracing/endpoints.py +++ b/rasa_sdk/tracing/endpoints.py @@ -12,9 +12,15 @@ def read_endpoint_config( filename: Text, endpoint_type: Text ) -> Optional["EndpointConfig"]: - """Read an endpoint configuration file from disk and extract one + """Read an endpoint configuration file from disk and extract one config. - config.""" + Args: + filename: the endpoint config file to read + endpoint_type: the type of the endpoint + + Returns: + The endpoint configuration of the passed type if it exists, `None` otherwise. + """ if not filename: return None @@ -28,8 +34,9 @@ def read_endpoint_config( return EndpointConfig.from_dict(content[endpoint_type]) except FileNotFoundError: logger.error( - "Failed to read endpoint configuration " - "from {}. No such file.".format(os.path.abspath(filename)) + "Failed to read endpoint configuration " "from {}. No such file.".format( + os.path.abspath(filename) + ) ) return None diff --git a/rasa_sdk/tracing/tracer_register.py b/rasa_sdk/tracing/tracer_register.py index 2985d5e8e..93b3a9db0 100644 --- a/rasa_sdk/tracing/tracer_register.py +++ b/rasa_sdk/tracing/tracer_register.py @@ -10,13 +10,15 @@ class ActionExecutorTracerRegister(metaclass=Singleton): def register_tracer(self, tracer: Tracer) -> None: """Register an ActionExecutor tracer. + Args: - trace: The tracer to register. + tracer: The tracer to register. """ self.tracer = tracer def get_tracer(self) -> Optional[Tracer]: """Get the ActionExecutor tracer. + Returns: The tracer. """ diff --git a/rasa_sdk/types.py b/rasa_sdk/types.py index 870643ded..b30ce6bea 100644 --- a/rasa_sdk/types.py +++ b/rasa_sdk/types.py @@ -4,9 +4,7 @@ class TrackerState(TypedDict): - """ - A dictionary representation of the state of a conversation. - """ + """A dictionary representation of the state of a conversation.""" # id of the source of the messages sender_id: Text @@ -31,9 +29,7 @@ class TrackerState(TypedDict): class DomainDict(TypedDict): - """ - A dictionary representation of the domain. - """ + """A dictionary representation of the domain.""" intents: List[Dict[Text, Any]] entities: List[Text] @@ -44,9 +40,7 @@ class DomainDict(TypedDict): class ActionCall(TypedDict): - """ - A dictionary representation of an action to be executed. - """ + """A dictionary representation of an action to be executed.""" # the name of the next action to be executed next_action: Optional[Text] diff --git a/rasa_sdk/utils.py b/rasa_sdk/utils.py index b1404b48e..189e98145 100644 --- a/rasa_sdk/utils.py +++ b/rasa_sdk/utils.py @@ -9,7 +9,17 @@ from ruamel.yaml import YAMLError from ruamel.yaml.constructor import DuplicateKeyError -from typing import AbstractSet, Any, Dict, List, Text, Optional, Coroutine, Union +from typing import ( + AbstractSet, + Any, + ClassVar, + Dict, + List, + Text, + Optional, + Coroutine, + Union, +) import rasa_sdk @@ -34,7 +44,13 @@ class Element(dict): """Represents an element in a list of elements in a rich message.""" - __acceptable_keys = ["title", "item_url", "image_url", "subtitle", "buttons"] + __acceptable_keys: ClassVar[List[Text]] = [ + "title", + "item_url", + "image_url", + "subtitle", + "buttons", + ] def __init__(self, *args, **kwargs): """Initializes an element in a list of elements in a rich message.""" @@ -54,7 +70,7 @@ class Button(dict): class Singleton(type): """Singleton metaclass.""" - _instances: Dict[Any, Any] = {} + _instances: ClassVar[Dict[Any, Any]] = {} def __call__(cls, *args: Any, **kwargs: Any) -> Any: """Call the class. @@ -68,6 +84,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: return cls._instances[cls] + @classmethod def clear(cls) -> None: """Clear the class.""" cls._instances = {} @@ -319,7 +336,7 @@ def update_sanic_log_level() -> None: async def call_potential_coroutine( - coroutine_or_return_value: Union[Any, Coroutine] + coroutine_or_return_value: Union[Any, Coroutine], ) -> Any: """Await if it's a coroutine.""" if asyncio.iscoroutine(coroutine_or_return_value): diff --git a/tests/conftest.py b/tests/conftest.py index 1d82a4380..7be22ab7a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,6 +81,7 @@ def run( class MockFormValidationAction(FormValidationAction): def __init__(self) -> None: + """Initializes the mock form validation action.""" self.fail_if_undefined("run") def fail_if_undefined(self, method_name: str) -> None: @@ -108,6 +109,7 @@ def name(self) -> str: class MockValidationAction(ValidationAction): def __init__(self) -> None: + """Initializes the mock validation action.""" self.fail_if_undefined("run") def fail_if_undefined(self, method_name: Text) -> None: diff --git a/tests/knowledge_base/test_actions.py b/tests/knowledge_base/test_actions.py index b962ac70e..c92900b82 100644 --- a/tests/knowledge_base/test_actions.py +++ b/tests/knowledge_base/test_actions.py @@ -20,7 +20,7 @@ def compare_slots(slot_list_1, slot_list_2): assert len(slot_list_2) == len(slot_list_1) for slot_1 in slot_list_1: - slot_2 = list(filter(lambda x: x["name"] == slot_1["name"], slot_list_2))[0] + slot_2 = next(iter(filter(lambda x: x["name"] == slot_1["name"], slot_list_2))) if isinstance(slot_2["value"], list): assert set(slot_1["value"]) == set(slot_2["value"]) diff --git a/tests/test_forms.py b/tests/test_forms.py index 31acd155d..25435c242 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -32,6 +32,11 @@ class TestFormValidationAction(FormValidationAction): def __init__(self, form_name: Text = "some_form") -> None: + """Initializes test form validation action. + + Args: + form_name: name of the form + """ self.name_of_form = form_name def name(self) -> Text: diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 9d706793a..c23cba30e 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -74,7 +74,6 @@ def test_tracker_with_slots(): def test_stack_in_tracker_state( stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]] ): - state = {"events": [], "sender_id": "old", "active_loop": {}, **stack_state} tracker = Tracker.from_dict(state) diff --git a/tests/tracing/conftest.py b/tests/tracing/conftest.py index ff6547da7..1d546210e 100644 --- a/tests/tracing/conftest.py +++ b/tests/tracing/conftest.py @@ -31,6 +31,7 @@ def udp_server() -> Generator[socket.socket, None, None]: class CapturingTestSpanExporter(trace_service.TraceServiceServicer): def __init__(self) -> None: + """Initializes the capture test span exporter.""" self.spans: Optional[RepeatedCompositeFieldContainer[ResourceSpans]] = None def Export( diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 0e2ea108d..e90f275ee 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -29,6 +29,8 @@ def previous_num_captured_spans(span_exporter: InMemorySpanExporter) -> int: class MockActionExecutor(ActionExecutor): def __init__(self) -> None: + """Initializes the mock action executor.""" + super().__init__() self.fail_if_undefined("run") def fail_if_undefined(self, method_name: Text) -> None: diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index 95598798a..25eae1ed2 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -74,9 +74,7 @@ async def test_tracing_action_executor_run( ) await mock_action_executor.run(action_call) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 @@ -171,9 +169,7 @@ def test_tracing_action_executor_create_api_response( dispatcher = get_dispatcher() mock_action_executor._create_api_response(events, dispatcher.messages) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index 37c849f0d..a2ffb9b4a 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -43,9 +43,7 @@ async def test_form_validation_action_run( await mock_validation_action.run(dispatcher, tracker, {}) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans # includes the child span for `_extract_validation_events` method call @@ -106,9 +104,7 @@ async def test_form_validation_action_extract_validation_events( dispatcher, tracker, {} ) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index 67a48d30c..75f14c68f 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -59,9 +59,7 @@ def test_server_webhook_custom_action_is_instrumented( assert response.status == 200 - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 @@ -100,9 +98,7 @@ def test_server_webhook_custom_action_is_not_instrumented( assert response.status == 200 - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 0 diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 6e11c1e68..2325a0026 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -45,9 +45,7 @@ async def test_validation_action_run( await mock_validation_action.run(dispatcher, tracker, {}) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 @@ -107,9 +105,7 @@ async def test_validation_action_extract_validation_events( dispatcher, tracker, {} ) - captured_spans: Sequence[ - ReadableSpan - ] = span_exporter.get_finished_spans() # type: ignore + captured_spans: Sequence[ReadableSpan] = span_exporter.get_finished_spans() # type: ignore num_captured_spans = len(captured_spans) - previous_num_captured_spans assert num_captured_spans == 1 diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index ef7b82ee1..27bc230a6 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -43,9 +43,7 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None def test_get_tracer_provider_returns_provider() -> None: - """Tests that get_tracer_provider returns a TracerProvider - if tracing is configured. - """ + """Tests that get_tracer_provider returns a TracerProvider if tracing is configured.""" # noqa: E501 parser = argparse.ArgumentParser() parser.add_argument("--endpoints", type=str, default=None)