From 0fe0c9250654ba55364f3a1fbc2263d62ec7b1a3 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 12 Dec 2024 11:30:00 +0000 Subject: [PATCH] refactor file upload handling and introduce new test for file upload endpoint --- aana/api/api_generation.py | 2 +- aana/core/models/image.py | 3 - aana/core/models/video.py | 2 - aana/tests/units/test_api_generation.py | 46 ------------- aana/tests/units/test_app_upload.py | 90 +++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 52 deletions(-) create mode 100644 aana/tests/units/test_app_upload.py diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index bd74e4c6..3df92480 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -4,7 +4,7 @@ from inspect import isasyncgenfunction from typing import Annotated, Any, get_origin -from fastapi import FastAPI, File, Form, Query, Request, UploadFile +from fastapi import FastAPI, Form, Query, Request from fastapi.responses import StreamingResponse from pydantic import ConfigDict, Field, ValidationError, create_model from pydantic.main import BaseModel diff --git a/aana/core/models/image.py b/aana/core/models/image.py index 78aca604..c1d10b51 100644 --- a/aana/core/models/image.py +++ b/aana/core/models/image.py @@ -13,11 +13,8 @@ BaseModel, ConfigDict, Field, - ValidationError, model_validator, ) -from pydantic_core import InitErrorDetails -from starlette.datastructures import UploadFile as StarletteUploadFile from typing_extensions import Self from aana.configs.settings import settings diff --git a/aana/core/models/video.py b/aana/core/models/video.py index 1197b644..60e59a52 100644 --- a/aana/core/models/video.py +++ b/aana/core/models/video.py @@ -15,10 +15,8 @@ BaseModel, ConfigDict, Field, - ValidationError, model_validator, ) -from pydantic_core import InitErrorDetails from typing_extensions import Self from aana.core.models.base import BaseListModel diff --git a/aana/tests/units/test_api_generation.py b/aana/tests/units/test_api_generation.py index 7f072695..e0beeda2 100644 --- a/aana/tests/units/test_api_generation.py +++ b/aana/tests/units/test_api_generation.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, ConfigDict, Field from aana.api.api_generation import Endpoint -from aana.exceptions.runtime import UploadedFileNotFound class InputModel(BaseModel): @@ -16,24 +15,6 @@ class InputModel(BaseModel): model_config = ConfigDict(extra="forbid") -class FileUploadModel(BaseModel): - """Model for a file upload input.""" - - content: str | None = Field( - None, - description="The name of the file to upload.", - ) - _file: bytes | None = None - - def set_files(self, files: dict[str, bytes]): - """Set files.""" - if self.content and not self._file: - raise UploadedFileNotFound(self.content) - self._file = files[self.content] - - model_config = ConfigDict(extra="forbid") - - class TestEndpointOutput(TypedDict): """The output of the test endpoint.""" @@ -48,24 +29,6 @@ async def run(self, input_data: InputModel) -> TestEndpointOutput: return {"output": input_data.input} -class TestFileUploadEndpoint(Endpoint): - """Test endpoint for file uploads.""" - - async def run(self, input_data: FileUploadModel) -> TestEndpointOutput: - """Run the endpoint.""" - return {"output": "file uploaded"} - - -class TestMultipleFileUploadEndpoint(Endpoint): - """Test endpoint for multiple file uploads.""" - - async def run( - self, input_data: FileUploadModel, input_data2: FileUploadModel - ) -> TestEndpointOutput: - """Run the endpoint.""" - return {"output": "file uploaded"} - - class TestEndpointMissingReturn(Endpoint): """Test endpoint for get_response_model with missing return type.""" @@ -121,15 +84,6 @@ def test_get_response_model(): assert ResponseModel.model_fields["output"].annotation == str -# def test_get_file_upload_field_multiple_file_uploads(): -# """Test the __get_file_upload_field function with multiple file uploads.""" -# endpoint = TestMultipleFileUploadEndpoint( -# name="test_endpoint", -# summary="Test endpoint", -# path="/test_endpoint", -# ) - - def test_get_response_model_missing_return(): """Test the get_response_model function with missing return type.""" endpoint = TestEndpointMissingReturn( diff --git a/aana/tests/units/test_app_upload.py b/aana/tests/units/test_app_upload.py new file mode 100644 index 00000000..536c5966 --- /dev/null +++ b/aana/tests/units/test_app_upload.py @@ -0,0 +1,90 @@ +# ruff: noqa: S101, S113 +import io +import json +from typing import TypedDict + +import requests +from pydantic import BaseModel, ConfigDict, Field + +from aana.api.api_generation import Endpoint +from aana.exceptions.runtime import UploadedFileNotFound + + +class FileUploadModel(BaseModel): + """Model for a file upload input.""" + + content: str | None = Field( + None, + description="The name of the file to upload.", + ) + _file: bytes | None = None + + def set_files(self, files: dict[str, bytes]): + """Set files.""" + if self.content: + if self.content not in files: + raise UploadedFileNotFound(self.content) + self._file = files[self.content] + + model_config = ConfigDict(extra="forbid") + + +class FileUploadEndpointOutput(TypedDict): + """The output of the file upload endpoint.""" + + text: str + + +class FileUploadEndpoint(Endpoint): + """File upload endpoint.""" + + async def run(self, file: FileUploadModel) -> FileUploadEndpointOutput: + """Upload a file. + + Args: + file (FileUploadModel): The file to upload + + Returns: + FileUploadEndpointOutput: The uploaded file + """ + file = file._file + return {"text": file.decode()} + + +deployments = [] + +endpoints = [ + { + "name": "file_upload", + "path": "/file_upload", + "summary": "Upload a file", + "endpoint_cls": FileUploadEndpoint, + } +] + + +def test_file_upload_app(create_app): + """Test the app with a file upload endpoint.""" + aana_app = create_app(deployments, endpoints) + + port = aana_app.port + route_prefix = "" + + # Check that the server is ready + response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready") + assert response.status_code == 200 + assert response.json() == {"ready": True} + + # Test lowercase endpoint + # data = {"content": "file.txt"} + data = {"file": {"content": "file.txt"}} + file = b"Hello world! This is a test." + files = {"file.txt": io.BytesIO(file)} + response = requests.post( + f"http://localhost:{port}{route_prefix}/file_upload", + data={"body": json.dumps(data)}, + files=files, + ) + assert response.status_code == 200, response.text + text = response.json().get("text") + assert text == file.decode()