Skip to content

Commit

Permalink
refactor file upload handling and introduce new test for file upload …
Browse files Browse the repository at this point in the history
…endpoint
  • Loading branch information
Aleksandr Movchan committed Dec 12, 2024
1 parent b89a08f commit 0fe0c92
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 52 deletions.
2 changes: 1 addition & 1 deletion aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

from fastapi import FastAPI, File, Form, Query, Request, UploadFile
from fastapi import FastAPI, Form, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import ConfigDict, Field, ValidationError, create_model
from pydantic.main import BaseModel
Expand Down
3 changes: 0 additions & 3 deletions aana/core/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
BaseModel,
ConfigDict,
Field,
ValidationError,
model_validator,
)
from pydantic_core import InitErrorDetails
from starlette.datastructures import UploadFile as StarletteUploadFile
from typing_extensions import Self

from aana.configs.settings import settings
Expand Down
2 changes: 0 additions & 2 deletions aana/core/models/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
BaseModel,
ConfigDict,
Field,
ValidationError,
model_validator,
)
from pydantic_core import InitErrorDetails
from typing_extensions import Self

from aana.core.models.base import BaseListModel
Expand Down
46 changes: 0 additions & 46 deletions aana/tests/units/test_api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pydantic import BaseModel, ConfigDict, Field

from aana.api.api_generation import Endpoint
from aana.exceptions.runtime import UploadedFileNotFound


class InputModel(BaseModel):
Expand All @@ -16,24 +15,6 @@ class InputModel(BaseModel):
model_config = ConfigDict(extra="forbid")


class FileUploadModel(BaseModel):
"""Model for a file upload input."""

content: str | None = Field(
None,
description="The name of the file to upload.",
)
_file: bytes | None = None

def set_files(self, files: dict[str, bytes]):
"""Set files."""
if self.content and not self._file:
raise UploadedFileNotFound(self.content)
self._file = files[self.content]

model_config = ConfigDict(extra="forbid")


class TestEndpointOutput(TypedDict):
"""The output of the test endpoint."""

Expand All @@ -48,24 +29,6 @@ async def run(self, input_data: InputModel) -> TestEndpointOutput:
return {"output": input_data.input}


class TestFileUploadEndpoint(Endpoint):
"""Test endpoint for file uploads."""

async def run(self, input_data: FileUploadModel) -> TestEndpointOutput:
"""Run the endpoint."""
return {"output": "file uploaded"}


class TestMultipleFileUploadEndpoint(Endpoint):
"""Test endpoint for multiple file uploads."""

async def run(
self, input_data: FileUploadModel, input_data2: FileUploadModel
) -> TestEndpointOutput:
"""Run the endpoint."""
return {"output": "file uploaded"}


class TestEndpointMissingReturn(Endpoint):
"""Test endpoint for get_response_model with missing return type."""

Expand Down Expand Up @@ -121,15 +84,6 @@ def test_get_response_model():
assert ResponseModel.model_fields["output"].annotation == str


# def test_get_file_upload_field_multiple_file_uploads():
# """Test the __get_file_upload_field function with multiple file uploads."""
# endpoint = TestMultipleFileUploadEndpoint(
# name="test_endpoint",
# summary="Test endpoint",
# path="/test_endpoint",
# )


def test_get_response_model_missing_return():
"""Test the get_response_model function with missing return type."""
endpoint = TestEndpointMissingReturn(
Expand Down
90 changes: 90 additions & 0 deletions aana/tests/units/test_app_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# ruff: noqa: S101, S113
import io
import json
from typing import TypedDict

import requests
from pydantic import BaseModel, ConfigDict, Field

from aana.api.api_generation import Endpoint
from aana.exceptions.runtime import UploadedFileNotFound


class FileUploadModel(BaseModel):
"""Model for a file upload input."""

content: str | None = Field(
None,
description="The name of the file to upload.",
)
_file: bytes | None = None

def set_files(self, files: dict[str, bytes]):
"""Set files."""
if self.content:
if self.content not in files:
raise UploadedFileNotFound(self.content)
self._file = files[self.content]

model_config = ConfigDict(extra="forbid")


class FileUploadEndpointOutput(TypedDict):
"""The output of the file upload endpoint."""

text: str


class FileUploadEndpoint(Endpoint):
"""File upload endpoint."""

async def run(self, file: FileUploadModel) -> FileUploadEndpointOutput:
"""Upload a file.
Args:
file (FileUploadModel): The file to upload
Returns:
FileUploadEndpointOutput: The uploaded file
"""
file = file._file
return {"text": file.decode()}


deployments = []

endpoints = [
{
"name": "file_upload",
"path": "/file_upload",
"summary": "Upload a file",
"endpoint_cls": FileUploadEndpoint,
}
]


def test_file_upload_app(create_app):
"""Test the app with a file upload endpoint."""
aana_app = create_app(deployments, endpoints)

port = aana_app.port
route_prefix = ""

# Check that the server is ready
response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
assert response.status_code == 200
assert response.json() == {"ready": True}

# Test lowercase endpoint
# data = {"content": "file.txt"}
data = {"file": {"content": "file.txt"}}
file = b"Hello world! This is a test."
files = {"file.txt": io.BytesIO(file)}
response = requests.post(
f"http://localhost:{port}{route_prefix}/file_upload",
data={"body": json.dumps(data)},
files=files,
)
assert response.status_code == 200, response.text
text = response.json().get("text")
assert text == file.decode()

0 comments on commit 0fe0c92

Please sign in to comment.