diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1199f9f70..5305d0a2b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ This project adheres to `Semantic Versioning`_ starting with version 0.11.0. Added ----- +- Abstract Actions can now be subclassed - add warning in case of mismatched version of rasa_core and rasa_core_sdk Changed diff --git a/rasa_core_sdk/executor.py b/rasa_core_sdk/executor.py index 2e668f424..668559785 100644 --- a/rasa_core_sdk/executor.py +++ b/rasa_core_sdk/executor.py @@ -149,8 +149,11 @@ def register_package(self, package): actions = utils.all_subclasses(Action) for action in actions: + meta = action.__dict__.get('Meta', False) + abstract = getattr(meta, 'abstract', False) if (not action.__module__.startswith("rasa_core.") and - not action.__module__.startswith("rasa_core_sdk.")): + not action.__module__.startswith("rasa_core_sdk.") and + not abstract): self.register_action(action) @staticmethod diff --git a/tests/test_actions.py b/tests/test_actions.py new file mode 100644 index 000000000..328ce4c92 --- /dev/null +++ b/tests/test_actions.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from rasa_core_sdk import Action, Tracker +from rasa_core_sdk.events import SlotSet +from rasa_core_sdk.executor import ActionExecutor, CollectingDispatcher + + +class CustomActionBase(Action): + @classmethod + def name(cls): + # Name method needed to test if base action was registered + return "base_action" + + class Meta: + abstract = True + + @staticmethod + def some_common_feature(): + return "test" + + def run(self, dispatcher, tracker, domain): + raise NotImplementedError + + +class CustomAction(CustomActionBase): + + @classmethod + def name(cls): + return "custom_action" + + def run(self, dispatcher, tracker, domain): + return [SlotSet('test', self.some_common_feature())] + + +def test_abstract_action(): + executor = ActionExecutor() + executor.register_package('tests') + assert CustomAction.name() in executor.actions + assert CustomActionBase.name() not in executor.actions + + dispatcher = CollectingDispatcher() + tracker = Tracker('test', {}, {}, [], False, None, {}, 'listen') + domain = {} + + events = CustomAction().run(dispatcher, tracker, domain) + assert events == [SlotSet('test', "test")]