
Decouple API interface from batching configuration #412

Open
marrrcin opened this issue Jan 14, 2025 · 2 comments
Labels
enhancement New feature or request

Comments

@marrrcin

🚀 Feature

Implement a consistent predict method interface in LitServe, independent of the batching configuration.

Motivation

Example code I'm using ⤵️
import random
import litserve as ls
import os
import time

class LitServeBatchingDemoAPI(ls.LitAPI):
    def setup(self, device):
        print(f"Loading models in process {os.getpid()}")

    def decode_request(self, request):
        return request["inputs"] 
    
    def predict(self, batch):
        print("Received batch of size", len(batch), batch)
        results = [random.random() for _ in batch]
        time.sleep(1.5)  # simulate model inference latency
        return results

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    print(f"Starting server in process {os.getpid()}")
    server = ls.LitServer(LitServeBatchingDemoAPI(),
                          workers_per_device=1,
                          )
    server.run(port=8000)

Currently, enabling batching in LitServe (by setting max_batch_size and batch_timeout) changes the expected implementation of the predict method in LitAPI subclasses. This creates several issues:

  1. The same API implementation behaves differently based on server configuration parameters
  2. Developers need to maintain different implementations or add conditional logic based on whether batching is enabled
  3. It violates the principle of separation of concerns - server configuration parameters should not affect the API contract

For example, with batching disabled, sending:

requests.post("http://127.0.0.1:8000/predict", json={"inputs": "my-input"})

results in predict receiving the single decoded value directly:

def predict(self, batch):
    # batch is a single input; len(batch) returns the length of the string
    print("Received batch of size", len(batch), batch)

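Enabling batching means constructing the server with the two parameters mentioned above; a sketch with illustrative values (the exact defaults and signature may differ between LitServe versions):

server = ls.LitServer(
    LitServeBatchingDemoAPI(),
    max_batch_size=4,     # group up to 4 concurrent requests into one predict call
    batch_timeout=0.05,   # wait up to 50 ms for a batch to fill before dispatching
    workers_per_device=1,
)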
With batching enabled (max_batch_size >= 2), the same request produces:

def predict(self, batch):
    # batch is a list of inputs
    print("Received batch of size", len(batch), batch)

This inconsistency makes it harder to maintain and test APIs, especially when batching configuration might change between development and production environments.
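As a workaround today, an implementation that has to survive both configurations ends up normalizing the input itself; a minimal sketch of the conditional logic from point 2 above (the isinstance check is illustrative and depends on what decode_request returns for a single request):

def predict(self, batch):
    # Workaround sketch: normalize to a list so one code path handles both
    # a single decoded input (batching off) and a list of inputs (batching on).
    is_batched = isinstance(batch, list)
    inputs = batch if is_batched else [batch]
    results = [random.random() for _ in inputs]
    return results if is_batched else results[0]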

Pitch

LitServe should provide a consistent interface for the predict method regardless of batching configuration.
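Concretely (and as echoed later in this thread), one shape this could take is for predict to always receive a list of decoded inputs, with a batch of size 1 when batching is off; the demo API above would then reduce to the following (hypothetical behavior, not how LitServe currently works):

def predict(self, batch):
    # Hypothetical contract: batch is always a list of decoded inputs,
    # with len(batch) == 1 when batching is disabled.
    return [random.random() for _ in batch]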

Additional context

Other serving frameworks such as TorchServe and RayServe take a similar approach to batching, but that doesn't mean LitServe can't improve on their design. A consistent API contract would make LitServe more intuitive and easier to use correctly.

More context in a recent Discord discussion.

The proposed change:

  • Maintains backward compatibility when batching is enabled
  • Simplifies API implementation by providing a consistent interface
  • Follows the principle of least surprise
  • Makes testing easier as there's only one behavior to test
  • Reduces potential bugs from incorrect handling of single vs batched inputs
marrrcin added the enhancement (New feature or request) label on Jan 14, 2025
@lantiga (Collaborator) commented Jan 14, 2025

hey @marrrcin, I see what you're saying: it's the usual tension between making it easy to start and making it possible to scale

btw are you proposing we always pass a collection or tensor with a batch dimension of 1 as the input to predict even when batching is off?

what if we found a way to opt-in for the API to be batched even when batching is off? this way we would retain backward compatibility and keep the story simple for users that don't care / don't want to know about batching, while still achieving your goal (make the API implementation independent from batching)

there are a few ways to do so:

  • setting an attribute of the instance in setup (self.batched = True)
  • relying on the name of the argument to predict (if it contains batch*, then it's batched - weird, but pytest works that way)
  • passing batched=True to the constructor of the API
  • passing batched_api=True to LitServer (although this pertains to the API rather than the server)

I'm not particularly fond of any of them, but I'm sure we can find something acceptable.
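For illustration only, the first option (an attribute set in setup) might look like this from the API author's side; self.batched and the behavior it implies are hypothetical, nothing here exists in LitServe today:

import litserve as ls

class MyBatchedAPI(ls.LitAPI):
    def setup(self, device):
        # Hypothetical opt-in flag from the list above; not a real LitServe attribute.
        self.batched = True

    def decode_request(self, request):
        return request["inputs"]

    def predict(self, batch):
        # With the opt-in, batch would always be a list of decoded inputs,
        # even when the server runs without max_batch_size / batch_timeout set.
        return [len(x) for x in batch]  # trivial stand-in for a real model

    def encode_response(self, output):
        return {"output": output}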

@marrrcin (Author)

To be honest, any solution that will unify the behaviour is fine by me. The

    "always pass a collection or tensor with a batch dimension of 1 as the input to predict even when batching is off"

option sounds the most reasonable and standard to me though.
