Skip to content

Commit

Permalink
Be more defensive in the prompt view handling (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
zerolab authored Jan 12, 2024
2 parents a94014e + ce57192 commit 54955b9
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 15 deletions.
40 changes: 40 additions & 0 deletions src/wagtail_ai/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from django import forms
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _


class PromptTextField(forms.CharField):
default_error_messages = {
"required": _(
"No text provided - please enter some text before using AI features."
),
}


class PromptUUIDField(forms.UUIDField):
default_error_messages = {
"required": _("Invalid prompt provided."),
"invalid": _("Invalid prompt provided."),
}


class PromptForm(forms.Form):
text = PromptTextField()
prompt = PromptUUIDField()

def clean_prompt(self):
prompt_uuid = self.cleaned_data["prompt"]
if prompt_uuid.version != 4:
raise ValidationError(
self.fields["prompt"].error_messages["invalid"], code="invalid"
)

return prompt_uuid

def errors_for_json_response(self) -> str:
errors_for_response = []
for _field, errors in self.errors.get_json_data().items():
for error in errors:
errors_for_response.append(error["message"])

return " \n".join(errors_for_response)
24 changes: 10 additions & 14 deletions src/wagtail_ai/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from django import forms
from django.http import JsonResponse
from django.utils.translation import gettext as _
from django.views.decorators.csrf import csrf_exempt
from wagtail.admin.ui.tables import UpdatedAtColumn
from wagtail.admin.viewsets.model import ModelViewSet

from . import ai, types
from .forms import PromptForm
from .models import Prompt

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,24 +75,18 @@ def _append_handler(*, prompt: Prompt, text: str) -> str:


@csrf_exempt
def process(request):
text = request.POST.get("text")
def process(request) -> JsonResponse:
prompt_form = PromptForm(request.POST)

if not text:
if not prompt_form.is_valid():
return JsonResponse(
{
"error": "No text provided - please enter some text before using AI \
features"
},
status=400,
{"error": prompt_form.errors_for_json_response()}, status=400
)

prompt_id = request.POST.get("prompt")

try:
prompt = Prompt.objects.get(uuid=prompt_id)
prompt = Prompt.objects.get(uuid=prompt_form.cleaned_data["prompt"])
except Prompt.DoesNotExist:
return JsonResponse({"error": "Invalid prompt provided"}, status=400)
return JsonResponse({"error": _("Invalid prompt provided.")}, status=400)

handlers = {
Prompt.Method.REPLACE: _replace_handler,
Expand All @@ -100,12 +96,12 @@ def process(request):
handler = handlers[Prompt.Method(prompt.method)]

try:
response = handler(prompt=prompt, text=text)
response = handler(prompt=prompt, text=prompt_form.cleaned_data["text"])
except AIHandlerException as e:
return JsonResponse({"error": str(e)}, status=400)
except Exception:
logger.exception("An unexpected error occurred.")
return JsonResponse({"error": "An unexpected error occurred"}, status=500)
return JsonResponse({"error": _("An unexpected error occurred.")}, status=500)

return JsonResponse({"message": response})

Expand Down
65 changes: 64 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import uuid

import pytest
from django.urls import reverse
from wagtail_ai.views import PromptEditForm, prompt_viewset
Expand Down Expand Up @@ -35,4 +37,65 @@ def test_prompt_model_admin_viewset_edit_view(client, setup_users, setup_prompt_
assert setup_prompt_object.label in str(response.content)


# TODO add tests for process view
@pytest.mark.django_db
def test_process_view_get_request(client, setup_users):
url = reverse("wagtail_ai:process")

superuser = setup_users
client.force_login(superuser)

response = client.get(url)
assert response.status_code == 400
assert response.json() == {
"error": "No text provided - please enter some text before using AI features. "
"\nInvalid prompt provided."
}


@pytest.mark.django_db
def test_process_view_post_without_text(client, setup_users):
url = reverse("wagtail_ai:process")

superuser = setup_users
client.force_login(superuser)

response = client.post(url, data={})
assert response.status_code == 400
assert response.json() == {
"error": "No text provided - please enter some text before using AI features. "
"\nInvalid prompt provided."
}


@pytest.mark.django_db
@pytest.mark.parametrize(
"prompt", [None, "NOT-A-UUID", str(uuid.uuid1()), str(uuid.uuid4())]
)
def test_process_view_with_bad_prompt_id(client, setup_users, prompt):
url = reverse("wagtail_ai:process")

superuser = setup_users
client.force_login(superuser)

data = {"text": "test"}
if prompt is not None:
data["prompt"] = prompt

response = client.post(url, data=data)
assert response.status_code == 400
assert response.json() == {"error": "Invalid prompt provided."}


@pytest.mark.django_db
def test_process_view_with_correct_prompt(client, setup_users, setup_prompt_object):
url = reverse("wagtail_ai:process")

superuser = setup_users
client.force_login(superuser)

response = client.post(
url, data={"text": "test", "prompt": str(setup_prompt_object.uuid)}
)
assert response.status_code == 200
# correct, the tests default is the echo backend
assert response.json() == {"message": "This is an echo backend: test"}
2 changes: 2 additions & 0 deletions tests/testapp/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,5 @@
},
},
}

FORMS_URLFIELD_ASSUME_HTTPS = True

0 comments on commit 54955b9

Please sign in to comment.