diff --git a/src/wagtail_ai/forms.py b/src/wagtail_ai/forms.py new file mode 100644 index 0000000..ddcc34b --- /dev/null +++ b/src/wagtail_ai/forms.py @@ -0,0 +1,40 @@ +from django import forms +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ + + +class PromptTextField(forms.CharField): + default_error_messages = { + "required": _( + "No text provided - please enter some text before using AI features." + ), + } + + +class PromptUUIDField(forms.UUIDField): + default_error_messages = { + "required": _("Invalid prompt provided."), + "invalid": _("Invalid prompt provided."), + } + + +class PromptForm(forms.Form): + text = PromptTextField() + prompt = PromptUUIDField() + + def clean_prompt(self): + prompt_uuid = self.cleaned_data["prompt"] + if prompt_uuid.version != 4: + raise ValidationError( + self.fields["prompt"].error_messages["invalid"], code="invalid" + ) + + return prompt_uuid + + def errors_for_json_response(self) -> str: + errors_for_response = [] + for _field, errors in self.errors.get_json_data().items(): + for error in errors: + errors_for_response.append(error["message"]) + + return " \n".join(errors_for_response) diff --git a/src/wagtail_ai/views.py b/src/wagtail_ai/views.py index 4dbeef5..ae96b68 100644 --- a/src/wagtail_ai/views.py +++ b/src/wagtail_ai/views.py @@ -3,11 +3,13 @@ from django import forms from django.http import JsonResponse +from django.utils.translation import gettext as _ from django.views.decorators.csrf import csrf_exempt from wagtail.admin.ui.tables import UpdatedAtColumn from wagtail.admin.viewsets.model import ModelViewSet from . import ai, types +from .forms import PromptForm from .models import Prompt logger = logging.getLogger(__name__) @@ -73,24 +75,18 @@ def _append_handler(*, prompt: Prompt, text: str) -> str: @csrf_exempt -def process(request): - text = request.POST.get("text") +def process(request) -> JsonResponse: + prompt_form = PromptForm(request.POST) - if not text: + if not prompt_form.is_valid(): return JsonResponse( - { - "error": "No text provided - please enter some text before using AI \ - features" - }, - status=400, + {"error": prompt_form.errors_for_json_response()}, status=400 ) - prompt_id = request.POST.get("prompt") - try: - prompt = Prompt.objects.get(uuid=prompt_id) + prompt = Prompt.objects.get(uuid=prompt_form.cleaned_data["prompt"]) except Prompt.DoesNotExist: - return JsonResponse({"error": "Invalid prompt provided"}, status=400) + return JsonResponse({"error": _("Invalid prompt provided.")}, status=400) handlers = { Prompt.Method.REPLACE: _replace_handler, @@ -100,12 +96,12 @@ def process(request): handler = handlers[Prompt.Method(prompt.method)] try: - response = handler(prompt=prompt, text=text) + response = handler(prompt=prompt, text=prompt_form.cleaned_data["text"]) except AIHandlerException as e: return JsonResponse({"error": str(e)}, status=400) except Exception: logger.exception("An unexpected error occurred.") - return JsonResponse({"error": "An unexpected error occurred"}, status=500) + return JsonResponse({"error": _("An unexpected error occurred.")}, status=500) return JsonResponse({"message": response}) diff --git a/tests/test_views.py b/tests/test_views.py index 905d796..7df7e62 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,3 +1,5 @@ +import uuid + import pytest from django.urls import reverse from wagtail_ai.views import PromptEditForm, prompt_viewset @@ -35,4 +37,65 @@ def test_prompt_model_admin_viewset_edit_view(client, setup_users, setup_prompt_ assert setup_prompt_object.label in str(response.content) -# TODO add tests for process view +@pytest.mark.django_db +def test_process_view_get_request(client, setup_users): + url = reverse("wagtail_ai:process") + + superuser = setup_users + client.force_login(superuser) + + response = client.get(url) + assert response.status_code == 400 + assert response.json() == { + "error": "No text provided - please enter some text before using AI features. " + "\nInvalid prompt provided." + } + + +@pytest.mark.django_db +def test_process_view_post_without_text(client, setup_users): + url = reverse("wagtail_ai:process") + + superuser = setup_users + client.force_login(superuser) + + response = client.post(url, data={}) + assert response.status_code == 400 + assert response.json() == { + "error": "No text provided - please enter some text before using AI features. " + "\nInvalid prompt provided." + } + + +@pytest.mark.django_db +@pytest.mark.parametrize( + "prompt", [None, "NOT-A-UUID", str(uuid.uuid1()), str(uuid.uuid4())] +) +def test_process_view_with_bad_prompt_id(client, setup_users, prompt): + url = reverse("wagtail_ai:process") + + superuser = setup_users + client.force_login(superuser) + + data = {"text": "test"} + if prompt is not None: + data["prompt"] = prompt + + response = client.post(url, data=data) + assert response.status_code == 400 + assert response.json() == {"error": "Invalid prompt provided."} + + +@pytest.mark.django_db +def test_process_view_with_correct_prompt(client, setup_users, setup_prompt_object): + url = reverse("wagtail_ai:process") + + superuser = setup_users + client.force_login(superuser) + + response = client.post( + url, data={"text": "test", "prompt": str(setup_prompt_object.uuid)} + ) + assert response.status_code == 200 + # correct, the tests default is the echo backend + assert response.json() == {"message": "This is an echo backend: test"} diff --git a/tests/testapp/settings.py b/tests/testapp/settings.py index abe3261..ef63a56 100644 --- a/tests/testapp/settings.py +++ b/tests/testapp/settings.py @@ -219,3 +219,5 @@ }, }, } + +FORMS_URLFIELD_ASSUME_HTTPS = True