From 3076a14c6b3a4c3696452a384c41670997bf20ae Mon Sep 17 00:00:00 2001 From: DonHaul Date: Tue, 21 May 2024 17:04:53 +0200 Subject: [PATCH] adding functionality for restart workflows --- backoffice/workflows/api/views.py | 49 ++++++++++++++++++++- backoffice/workflows/constants.py | 13 ++++++ backoffice/workflows/tests/test_views.py | 56 +++++++++++++++++++++++- backoffice/workflows/urls.py | 10 +++++ poetry.lock | 6 +-- 5 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 backoffice/workflows/urls.py diff --git a/backoffice/workflows/api/views.py b/backoffice/workflows/api/views.py index 69f473d3..48fe8a39 100644 --- a/backoffice/workflows/api/views.py +++ b/backoffice/workflows/api/views.py @@ -1,3 +1,4 @@ +import requests from django.shortcuts import get_object_or_404 from django_elasticsearch_dsl_drf.viewsets import BaseDocumentViewSet from rest_framework import status, viewsets @@ -9,7 +10,7 @@ from backoffice.workflows.documents import WorkflowDocument from backoffice.workflows.models import Workflow, WorkflowTicket -from ..constants import WORKFLOW_DAG, ResolutionDags +from ..constants import AUTHOR_DAGS, WORKFLOW_DAG, ResolutionDags from .serializers import ( AuthorResolutionSerializer, WorkflowDocumentSerializer, @@ -98,6 +99,52 @@ def resolve(self, request, pk=None): ResolutionDags[serializer.validated_data["value"]].label, pk, extra_data ) + @action(detail=True, methods=["post"]) + def restart(self, request, pk=None): + restart_current_task = request.data.get("restart_current_task", False) + + workflow = Workflow.objects.get(id=pk) + + data = {"dry_run": False, "dag_run_id": pk, "reset_dag_runs": True} + + executed_dags_for_workflow = {} + # find dags that were executed + for dag_id in AUTHOR_DAGS[workflow.workflow_type]: + response = requests.get( + f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/dagRuns/{pk}", + json=data, + headers=airflow_utils.AIRFLOW_HEADERS, + ) + if response.status_code == status.HTTP_200_OK: + executed_dags_for_workflow[dag_id] = response.content + + # assumes current task is one of the failed tasks + if restart_current_task: + + data = {"dry_run": False, "dag_run_id": pk, "reset_dag_runs": False, "only_failed": True} + + response = requests.post( + f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/clearTaskInstances", + json=data, + headers=airflow_utils.AIRFLOW_HEADERS, + ) + if response.status_code != 200: + return Response({"error": "Failed to restart task"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + return Response(response.json(), status=status.HTTP_200_OK) + + else: + # delete every executed_dag for this workflow + for i, dag_id in enumerate(executed_dags_for_workflow): + + # delete all executions of workflow + response = requests.delete( + f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/dagRuns/{pk}", + headers=airflow_utils.AIRFLOW_HEADERS, + ) + if response.status_code != 204: + return Response({"error": "Failed to restart"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + class WorkflowDocumentView(BaseDocumentViewSet): def __init__(self, *args, **kwargs): diff --git a/backoffice/workflows/constants.py b/backoffice/workflows/constants.py index 353b6cd4..cfadccf8 100644 --- a/backoffice/workflows/constants.py +++ b/backoffice/workflows/constants.py @@ -41,3 +41,16 @@ class WorkflowType(models.TextChoices): class ResolutionDags(models.TextChoices): accept = "accept", "author_create_approved_dag" reject = "reject", "author_create_rejected_dag" + + +# author dags for each workflow type +AUTHOR_DAGS = { + WorkflowType.HEP_CREATE: "", + WorkflowType.HEP_UPDATE: "", + WorkflowType.AUTHOR_CREATE: ( + "author_create_initialization_dag", + "author_create_approved_dag", + "author_create_rejected_dag", + ), + WorkflowType.AUTHOR_UPDATE: ("author_update_dag",), +} diff --git a/backoffice/workflows/tests/test_views.py b/backoffice/workflows/tests/test_views.py index 4ca71b26..ae87774a 100644 --- a/backoffice/workflows/tests/test_views.py +++ b/backoffice/workflows/tests/test_views.py @@ -11,10 +11,10 @@ from backoffice.workflows.api.serializers import WorkflowTicketSerializer from backoffice.workflows.constants import StatusChoices -from backoffice.workflows.models import WorkflowTicket User = get_user_model() Workflow = apps.get_model(app_label="workflows", model_name="Workflow") +WorkflowTicket = apps.get_model(app_label="workflows", model_name="WorkflowTicket") class BaseTransactionTestCase(TransactionTestCase): @@ -210,6 +210,12 @@ class TestAuthorWorkflowViewSet(BaseTransactionTestCase): reset_sequences = True fixtures = ["backoffice/fixtures/groups.json"] + def setUp(self): + super().setUp() + self.workflow = Workflow.objects.create( + data={}, status="running", core=True, is_update=False, workflow_type="AUTHOR_CREATE" + ) + @patch("backoffice.workflows.airflow_utils.requests.post") def test_create_author(self, mock_post): self.api_client.force_authenticate(user=self.curator) @@ -266,3 +272,51 @@ def test_reject_author(self, mock_post): ) self.assertEqual(response.status_code, 200) + + @patch("backoffice.workflows.airflow_utils.requests.post") + def test_restart_full_dagrun(self, mock_post): + + mock_response = mock_post.return_value + mock_response.status_code = status.HTTP_200_OK + mock_response.json.return_value = {"key": "value"} + + self.api_client.force_authenticate(user=self.curator) + url = reverse( + "api:workflows-authors-restart", + kwargs={"pk": self.workflow.id}, + ) + + response = self.api_client.post(url) + + self.assertEqual(response.status_code, 200) + + @patch("backoffice.workflows.airflow_utils.requests.post") + def test_restart_a_task(self, mock_post): + + mock_response = mock_post.return_value + mock_response.status_code = status.HTTP_200_OK + mock_response.json.return_value = {"key": "value"} + + self.api_client.force_authenticate(user=self.curator) + url = reverse( + "api:workflows-authors-restart", + kwargs={"pk": self.workflow.id}, + ) + + response = self.api_client.post(url, json={"task_ids": ["set_workflow_status_to_running"]}) + self.assertEqual(response.status_code, 200) + + @patch("backoffice.workflows.airflow_utils.requests.post") + def test_restart_with_params(self, mock_post): + + mock_response = mock_post.return_value + mock_response.status_code = status.HTTP_200_OK + mock_response.json.return_value = {"key": "value"} + self.api_client.force_authenticate(user=self.curator) + url = reverse( + "api:workflows-authors-restart", + kwargs={"pk": self.workflow.id}, + ) + + response = self.api_client.post(url, json={"params": {"workflow_id": self.workflow.id}}) + self.assertEqual(response.status_code, 200) diff --git a/backoffice/workflows/urls.py b/backoffice/workflows/urls.py new file mode 100644 index 00000000..2d25d669 --- /dev/null +++ b/backoffice/workflows/urls.py @@ -0,0 +1,10 @@ +from django.urls import path +from . import views + +urlpatterns = [ + path( + "workflows///restart", + views.RestartWorkflowView.as_view(), + name="workflow_restart", + ), +] diff --git a/poetry.lock b/poetry.lock index 97316303..a394fa5b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "amqp" @@ -1146,7 +1146,7 @@ files = [ django = "*" django-stubs-ext = ">=4.2.2" mypy = [ - {version = ">=1.0.0"}, + {version = ">=1.0.0", optional = true, markers = "extra != \"compatible-mypy\""}, {version = "==1.5.*", optional = true, markers = "extra == \"compatible-mypy\""}, ] types-pytz = "*" @@ -1238,7 +1238,7 @@ files = [ [package.dependencies] django-stubs = ">=4.2.4" mypy = [ - {version = ">=0.991"}, + {version = ">=0.991", optional = true, markers = "extra != \"compatible-mypy\""}, {version = "==1.5.*", optional = true, markers = "extra == \"compatible-mypy\""}, ] requests = ">=2.0.0"