Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor File Uploads #216

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,75 @@ aana_app.deploy() # Deploys the application.

All you need to do is define the deployments and endpoints you want to use in your application, and Aana SDK will take care of the rest.


## API

Aana SDK uses form data for API requests, which allows sending both binary data and structured fields in a single request. The request body is sent as a JSON string in the `body` field, and any binary data is sent as files.

### Making API Requests

You can send requests to the SDK endpoints with only structured data or a combination of structured data and binary data.

#### Only Structured Data
When your request includes only structured data, you can send it as a JSON string in the `body` field.

- **cURL Example:**
```bash
curl http://127.0.0.1:8000/endpoint \
-F body='{"input": "data", "param": "value"}'
```

- **Python Example:**
```python
import json, requests

url = "http://127.0.0.1:8000/endpoint"
body = {
"input": "data",
"param": "value"
}

response = requests.post(
url,
data={"body": json.dumps(body)}
)

print(response.json())
```

#### With Binary Data
When your request includes binary files (images, audio, etc.), you can send them as files in the request and include the names of the files in the `body` field as a reference.

For example, if you want to send an image, you can use [`aana.core.models.image.ImageInput`](https://mobiusml.github.io/aana_sdk/reference/models/media/#aana.core.models.ImageInput) as the input type that supports binary data upload. The `content` field in the input type should be set to the name of the file you are sending.

- **cURL Example:**
```bash
curl http://127.0.0.1:8000/process_images \
-H "Content-Type: multipart/form-data" \
-F body='{"image": {"content": "file1"}}' \
-F file1="@image.jpeg"
```

- **Python Example:**
```python
import json, requests

url = "http://127.0.0.1:8000/process_images"
body = {
"image": {"content": "file1"}
}
with open("image.jpeg", "rb") as file:
files = {"file1": file}

response = requests.post(
url,
data={"body": json.dumps(body)},
files=files
)

print(response.text)
```

## Serve Config Files

The [Serve Config Files](https://docs.ray.io/en/latest/serve/production-guide/config.html#serve-config-files) is the recommended way to deploy and update your applications in production. Aana SDK provides a way to build the Serve Config Files for the Aana applications. See the [Serve Config Files documentation](https://mobiusml.github.io/aana_sdk/pages/serve_config_files/) on how to build and deploy the applications using the Serve Config Files.
Expand Down
78 changes: 15 additions & 63 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
from inspect import isasyncgenfunction
from typing import Annotated, Any, get_origin

from fastapi import FastAPI, File, Form, Query, UploadFile
from fastapi import FastAPI, Form, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import ConfigDict, Field, ValidationError, create_model
from pydantic.main import BaseModel
from starlette.datastructures import UploadFile as StarletteUploadFile

from aana.api.event_handlers.event_handler import EventHandler
from aana.api.event_handlers.event_manager import EventManager
from aana.api.exception_handler import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.configs.settings import settings as aana_settings
from aana.core.models.exception import ExceptionResponseModel
from aana.exceptions.runtime import (
MultipleFileUploadNotAllowed,
)
from aana.storage.repository.task import TaskRepository
from aana.storage.session import get_session

Expand All @@ -32,19 +30,6 @@ def get_default_values(func):
}


@dataclass
class FileUploadField:
"""Class used to represent a file upload field.

Attributes:
name (str): Name of the field.
description (str): Description of the field.
"""

name: str
description: str


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand Down Expand Up @@ -105,14 +90,12 @@ def register(
RequestModel = self.get_request_model()
ResponseModel = self.get_response_model()

file_upload_field = self.__get_file_upload_field()
if self.event_handlers:
for handler in self.event_handlers:
event_manager.register_handler_for_events(handler, [self.path])

route_func = self.__create_endpoint_func(
RequestModel=RequestModel,
file_upload_field=file_upload_field,
event_manager=event_manager,
)

Expand Down Expand Up @@ -229,39 +212,6 @@ def get_response_model(self) -> type[BaseModel]:
model_name, **output_fields, __config__=ConfigDict(extra="forbid")
)

def __get_file_upload_field(self) -> FileUploadField | None:
"""Get the file upload field for the endpoint.

Returns:
Optional[FileUploadField]: File upload field or None if not found.

Raises:
MultipleFileUploadNotAllowed: If multiple inputs require file upload.
"""
file_upload_field = None
for arg_name, arg_type in self.run.__annotations__.items():
if arg_name == "return":
continue

# check if pydantic model has file_upload field and it's set to True
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
file_upload_enabled = arg_type.model_config.get("file_upload", False)
file_upload_description = arg_type.model_config.get(
"file_upload_description", ""
)
else:
file_upload_enabled = False
file_upload_description = ""

if file_upload_enabled and file_upload_field is None:
file_upload_field = FileUploadField(
name=arg_name, description=file_upload_description
)
elif file_upload_enabled and file_upload_field is not None:
# raise an exception if multiple inputs require file upload
raise MultipleFileUploadNotAllowed(arg_name)
return file_upload_field

@classmethod
def is_streaming_response(cls) -> bool:
"""Check if the endpoint returns a streaming response.
Expand All @@ -274,15 +224,14 @@ def is_streaming_response(cls) -> bool:
def __create_endpoint_func( # noqa: C901
self,
RequestModel: type[BaseModel],
file_upload_field: FileUploadField | None = None,
event_manager: EventManager | None = None,
) -> Callable:
"""Create a function for routing an endpoint."""
# Copy path to a bound variable so we don't retain an external reference
bound_path = self.path

async def route_func_body( # noqa: C901
body: str, files: list[UploadFile] | None = None, defer=False
body: str, files: dict[str, bytes] | None = None, defer=False
):
if not self.initialized:
await self.initialize()
Expand All @@ -294,9 +243,11 @@ async def route_func_body( # noqa: C901
data = RequestModel.model_validate_json(body)

# if the input requires file upload, add the files to the data
if file_upload_field and files:
files_as_bytes = [await file.read() for file in files]
getattr(data, file_upload_field.name).set_files(files_as_bytes)
if files:
for field_name in data.model_fields:
field_value = getattr(data, field_name)
if hasattr(field_value, "set_files"):
field_value.set_files(files)

# We have to do this instead of data.dict() because
# data.dict() will convert all nested models to dicts
Expand Down Expand Up @@ -336,20 +287,21 @@ async def generator_wrapper() -> AsyncGenerator[bytes, None]:
return custom_exception_handler(None, e)
return AanaJSONResponse(content=output)

if file_upload_field:
files = File(None, description=file_upload_field.description)
else:
files = None

async def route_func(
request: Request,
body: str = Form(...),
files=files,
defer: bool = Query(
description="Defer execution of the endpoint to the task queue.",
default=False,
include_in_schema=aana_settings.task_queue.enabled,
),
):
form_data = await request.form()

files: dict[str, bytes] = {}
for field_name, field_value in form_data.items():
if isinstance(field_value, StarletteUploadFile):
files[field_name] = await field_value.read()
return await route_func_body(body=body, files=files, defer=defer)

return route_func
Expand Down
Loading
Loading