diff --git a/oarepo_workflows/ext.py b/oarepo_workflows/ext.py index 8c15019..66f306f 100644 --- a/oarepo_workflows/ext.py +++ b/oarepo_workflows/ext.py @@ -14,6 +14,7 @@ import importlib_metadata from invenio_drafts_resources.services.records.uow import ParentRecordCommitOp +from invenio_records_resources.records import Record from invenio_records_resources.services.uow import RecordCommitOp, unit_of_work from oarepo_runtime.datastreams.utils import get_record_service_for_record @@ -28,7 +29,6 @@ from flask import Flask from flask_principal import Identity from invenio_drafts_resources.records import ParentRecord - from invenio_records_resources.records import Record from invenio_records_resources.services.uow import UnitOfWork from oarepo_workflows.base import ( @@ -207,18 +207,37 @@ def default_workflow_events(self) -> dict: """ return self.app.config.get("DEFAULT_WORKFLOW_EVENTS", {}) - def get_workflow(self, record: Record) -> Workflow: + def get_workflow(self, record: Record | dict) -> Workflow: """Get the workflow for a record. :param record: record to get the workflow for :raises MissingWorkflowError: if the workflow is not found :raises InvalidWorkflowError: if the workflow is invalid """ - parent = record.parent # noqa for typing: we do not have a better type for record with parent + if isinstance(record, Record): + try: + parent = record.parent # noqa for typing: we do not have a better type for record with parent + except AttributeError as e: + raise MissingWorkflowError("Record does not have a parent attribute, is it a draft-enabled record?", + record=record) from e + try: + workflow_id = parent.workflow + except AttributeError as e: + raise MissingWorkflowError("Parent record does not have a workflow attribute.", + record=record) from e + else: + try: + parent = record["parent"] + except KeyError as e: + raise MissingWorkflowError("Record does not have a parent attribute.", + record=record) from e + try: + workflow_id = parent["workflow"] + except KeyError as e: + raise MissingWorkflowError("Parent record does not have a workflow attribute.", record=record) from e + try: - return self.record_workflows[parent.workflow] - except AttributeError as e: - raise MissingWorkflowError("Workflow not found.", record=record) from e + return self.record_workflows[workflow_id] except KeyError as e: raise InvalidWorkflowError( f"Workflow {parent.workflow} doesn't exist in the configuration.", diff --git a/oarepo_workflows/requests/policy.py b/oarepo_workflows/requests/policy.py index 4b539cf..c1f6f9f 100644 --- a/oarepo_workflows/requests/policy.py +++ b/oarepo_workflows/requests/policy.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from flask_principal import Identity + from invenio_records_resources.records.api import Record class WorkflowRequestPolicy: @@ -76,7 +77,7 @@ def items(self) -> list[tuple[str, WorkflowRequest]]: return ret def applicable_workflow_requests( - self, identity: Identity, **context: Any + self, identity: Identity, *, record: Record, **context: Any ) -> list[tuple[str, WorkflowRequest]]: """Return a list of applicable requests for the identity and context. @@ -87,6 +88,6 @@ def applicable_workflow_requests( ret = [] for name, request in self.items(): - if request.is_applicable(identity, **context): + if request.is_applicable(identity, record=record, **context): ret.append((name, request)) return ret diff --git a/oarepo_workflows/requests/requests.py b/oarepo_workflows/requests/requests.py index d133dea..aadfc23 100644 --- a/oarepo_workflows/requests/requests.py +++ b/oarepo_workflows/requests/requests.py @@ -26,6 +26,7 @@ from datetime import timedelta from invenio_records_permissions.generators import Generator + from invenio_records_resources.records.api import Record from invenio_requests.customizations.request_types import RequestType from oarepo_workflows.requests.events import WorkflowEvent @@ -76,21 +77,21 @@ def recipient_entity_reference(self, **context: Any) -> dict | None: """ return RecipientEntityReference(self, **context) - def is_applicable(self, identity: Identity, **context: Any) -> bool: + def is_applicable(self, identity: Identity, *, record: Record, **context: Any) -> bool: """Check if the request is applicable for the identity and context (which might include record, community, ...). :param identity: Identity of the requester. :param context: Context of the request that is passed to the requester generators. """ try: - p = Permission(*self.requester_generator.needs(**context)) + p = Permission(*self.requester_generator.needs(record=record, **context)) if not p.needs: return False - p.excludes.update(self.requester_generator.excludes(**context)) + p.excludes.update(self.requester_generator.excludes(record=record, **context)) if not p.allows(identity): return False - if hasattr(self.request_type, "can_create"): - return self.request_type.can_create(identity, **context) + if hasattr(self.request_type, "is_applicable_to"): + return self.request_type.is_applicable_to(identity, topic=record, **context) return True except InvalidConfigurationError: raise @@ -99,7 +100,7 @@ def is_applicable(self, identity: Identity, **context: Any) -> bool: return False @cached_property - def request_type(self) -> type[RequestType]: + def request_type(self) -> RequestType: """Return the request type for the workflow request.""" if self._request_type is None: raise InvalidConfigurationError(