Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
user actions: adding functionality for restart workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
DonHaul committed Jul 17, 2024
1 parent 9772db3 commit b255a7b
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 4 deletions.
53 changes: 52 additions & 1 deletion backoffice/workflows/api/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import requests
from django.shortcuts import get_object_or_404
from django_elasticsearch_dsl_drf.viewsets import BaseDocumentViewSet
from rest_framework import status, viewsets
Expand All @@ -9,7 +10,7 @@
from backoffice.workflows.documents import WorkflowDocument
from backoffice.workflows.models import Workflow, WorkflowTicket

from ..constants import WORKFLOW_DAG, ResolutionDags
from ..constants import AUTHOR_DAGS, WORKFLOW_DAG, ResolutionDags
from .serializers import (
AuthorResolutionSerializer,
WorkflowDocumentSerializer,
Expand Down Expand Up @@ -98,6 +99,56 @@ def resolve(self, request, pk=None):
ResolutionDags[serializer.validated_data["value"]].label, pk, extra_data
)

@action(detail=True, methods=["post"])
def restart(self, request, pk=None):

params = request.data.get("params", None)
restart_current_task = request.data.get("restart_current_task", False)

workflow = Workflow.objects.get(id=pk)

data = {"dry_run": False, "dag_run_id": pk, "reset_dag_runs": True}

executed_dags_for_workflow = {}
# find dags that were executed
for dag_id in AUTHOR_DAGS[workflow.workflow_type]:
response = requests.get(
f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/dagRuns/{pk}",
json=data,
headers=airflow_utils.AIRFLOW_HEADERS,
)
if response.status_code == status.HTTP_200_OK:
executed_dags_for_workflow[dag_id] = response.content

# assumes current task is one of the failed tasks
if restart_current_task:

data = {"dry_run": False, "dag_run_id": pk, "reset_dag_runs": False, "only_failed": True}

response = requests.post(
f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/clearTaskInstances",
json=data,
headers=airflow_utils.AIRFLOW_HEADERS,
)
if response.status_code != 200:
return Response({"error": "Failed to restart task"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

return Response(response.json(), status=status.HTTP_200_OK)

else:
# delete every executed_dag for this workflow
for i, dag_id in enumerate(executed_dags_for_workflow):

# delete all executions of workflow
response = requests.delete(
f"{airflow_utils.AIRFLOW_BASE_URL}/api/v1/dags/{dag_id}/dagRuns/{pk}",
headers=airflow_utils.AIRFLOW_HEADERS,
)

return airflow_utils.trigger_airflow_dag(WORKFLOW_DAG[workflow.workflow_type], pk, params)

return Response({"error": "Failed to restart"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


class WorkflowDocumentView(BaseDocumentViewSet):
def __init__(self, *args, **kwargs):
Expand Down
13 changes: 13 additions & 0 deletions backoffice/workflows/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,16 @@ class WorkflowType(models.TextChoices):
class ResolutionDags(models.TextChoices):
accept = "accept", "author_create_approved_dag"
reject = "reject", "author_create_rejected_dag"


# author dags for each workflow type
AUTHOR_DAGS = {
WorkflowType.HEP_CREATE: "",
WorkflowType.HEP_UPDATE: "",
WorkflowType.AUTHOR_CREATE: (
"author_create_initialization_dag",
"author_create_approved_dag",
"author_create_rejected_dag",
),
WorkflowType.AUTHOR_UPDATE: ("author_update_dag",),
}
61 changes: 61 additions & 0 deletions backoffice/workflows/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from unittest.mock import patch

from django.apps import apps
Expand Down Expand Up @@ -210,6 +211,12 @@ class TestAuthorWorkflowViewSet(BaseTransactionTestCase):
reset_sequences = True
fixtures = ["backoffice/fixtures/groups.json"]

def setUp(self):
super().setUp()
self.workflow = Workflow.objects.create(
data={}, status="running", core=True, is_update=False, workflow_type="AUTHOR_CREATE"
)

@patch("backoffice.workflows.airflow_utils.requests.post")
def test_create_author(self, mock_post):
self.api_client.force_authenticate(user=self.curator)
Expand Down Expand Up @@ -266,3 +273,57 @@ def test_reject_author(self, mock_post):
)

self.assertEqual(response.status_code, 200)

def test_restart_full_dagrun(self):

self.api_client.force_authenticate(user=self.curator)
url = reverse(
"api:workflows-authors-restart",
kwargs={"pk": self.workflow.id},
)
with patch_requests() as (mock_post, mock_get, mock_delete):
response = self.api_client.post(url)

self.assertEqual(response.status_code, 200)

def test_restart_a_task(self):

self.api_client.force_authenticate(user=self.curator)
url = reverse(
"api:workflows-authors-restart",
kwargs={"pk": self.workflow.id},
)
with patch_requests() as (mock_post, mock_get, mock_delete):
response = self.api_client.post(url, json={"task_ids": ["set_workflow_status_to_running"]})
self.assertEqual(response.status_code, 200)

def test_restart_with_params(self):

self.api_client.force_authenticate(user=self.curator)
url = reverse(
"api:workflows-authors-restart",
kwargs={"pk": self.workflow.id},
)
with patch_requests() as (mock_post, mock_get, mock_delete):
response = self.api_client.post(url, json={"params": {"workflow_id": self.workflow.id}})
self.assertEqual(response.status_code, 200)


@contextmanager
def patch_requests():
with patch("requests.post") as mock_post, patch("requests.get") as mock_get, patch(
"requests.delete"
) as mock_delete:

# Configure the mock for requests.post
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"key": "value"}

# Configure the mock for requests.get
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = {"data": "some_data"}

# Configure the mock for requests.delete
mock_delete.return_value.status_code = 204

yield mock_post, mock_get, mock_delete
10 changes: 10 additions & 0 deletions backoffice/workflows/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.urls import path
from . import views

urlpatterns = [
path(
"workflows/<str:dag_id>/<str:workflow_id>/restart",
views.RestartWorkflowView.as_view(),
name="workflow_restart",
),
]
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b255a7b

Please sign in to comment.