Skip to content

Commit

Permalink
feat: updated queue implementation (#7)
Browse files Browse the repository at this point in the history
* feat: updated queue implementation

this brings the dart/flutter client to parity with the other clients for queue-based apis

* chore: update pucspec
  • Loading branch information
drochetti authored Sep 11, 2024
1 parent 74bf371 commit 79132b9
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 151 deletions.
42 changes: 13 additions & 29 deletions example/lib/main.dart
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import 'package:fal_client/fal_client.dart';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';

import 'types.dart';

// You can use the proxyUrl to protect your credentials in production.
// final fal = FalClient.withProxy('http://localhost:3333/api/_fal/proxy');
// final fal = FalClient.withProxy('http://localhost:3333/api/fal/proxy');

// You can also use the credentials locally for development, but make sure
// you protected your credentials behind a proxy in production.
final fal = FalClient.withCredentials('FAL_KEY_ID:FAL_KEY_SECRET');
final fal = FalClient.withCredentials('FAL_KEY');

void main() {
runApp(const FalSampleApp());
Expand Down Expand Up @@ -42,29 +41,26 @@ class TextoToImageScreen extends StatefulWidget {
}

class _TextoToImageScreenState extends State<TextoToImageScreen> {
final ImagePicker _picker = ImagePicker();
XFile? _image;
final TextEditingController _promptController = TextEditingController();
String? _generatedImageUrl;
bool _isProcessing = false;

Future<String> generateImage(XFile image, String prompt) async {
final result = await fal.subscribe(textToImageId, input: {
'prompt': prompt,
'image_url': image,
});
return result['image']['url'] as String;
Future<String> generateImage(String prompt) async {
final output = await fal.subscribe("fal-ai/flux/dev",
input: {
'prompt': prompt,
},
mode: SubscriptionMode.pollingWithInterval(Duration(seconds: 1)));
print(output.requestId);
final data = FluxOutput.fromMap(output.data);
return data.images.first.url;
}

void _onGenerateImage() async {
if (_image == null || _promptController.text.isEmpty) {
// Handle error: either image not selected or prompt not entered
return;
}
setState(() {
_isProcessing = true;
});
String imageUrl = await generateImage(_image!, _promptController.text);
String imageUrl = await generateImage(_promptController.text);
setState(() {
_generatedImageUrl = imageUrl;
_isProcessing = false;
Expand All @@ -75,26 +71,14 @@ class _TextoToImageScreenState extends State<TextoToImageScreen> {
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: const Text('Illusion Diffusion'),
title: const Text('FLUX.1 [dev]'),
),
body: Padding(
padding: const EdgeInsets.all(16.0),
child: Column(
mainAxisAlignment: MainAxisAlignment.start,
crossAxisAlignment: CrossAxisAlignment.stretch,
children: <Widget>[
ElevatedButton(
onPressed: () async {
final XFile? image =
await _picker.pickImage(source: ImageSource.gallery);
setState(() {
_image = image;
});
},
child: const Text('Pick Image'),
),
// if (_image != null)
// Image,
TextFormField(
controller: _promptController,
decoration: const InputDecoration(labelText: 'Imagine...'),
Expand Down
19 changes: 11 additions & 8 deletions example/lib/types.dart
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
class IllusionDiffusionResult {
final ImageRef image;
class FluxOutput {
final List<ImageRef> images;
final int seed;

IllusionDiffusionResult({required this.image, required this.seed});
FluxOutput({
required this.images,
required this.seed,
});

factory IllusionDiffusionResult.fromMap(Map<String, dynamic> json) {
return IllusionDiffusionResult(
image: ImageRef.fromMap(json['image'] as Map<String, dynamic>),
factory FluxOutput.fromMap(Map<String, dynamic> json) {
return FluxOutput(
images: ((json['images'] ?? []) as List<dynamic>)
.map((e) => ImageRef.fromMap(e as Map<String, dynamic>))
.toList(),
seed: (json['seed'] * 1).round(),
);
}
Expand All @@ -27,5 +32,3 @@ class ImageRef {
);
}
}

const textToImageId = '54285744-illusion-diffusion';
15 changes: 15 additions & 0 deletions lib/src/auth.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import './config.dart';
import './http.dart';

const _defaultTokenExpiration = 180;

Future<String> createJwtToken(
{required List<String> apps,
required Config config,
int expiration = _defaultTokenExpiration}) async {
final response = await sendRequest("https://rest.alpha.fal.ai/tokens/",
config: config,
method: "POST",
input: {"allowed_apps": apps, "token_expiration": expiration});
return response as String;
}
157 changes: 96 additions & 61 deletions lib/src/client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@ import './exception.dart';
import './http.dart';
import './queue.dart';
import './storage.dart';
import 'common.dart';

abstract class SubscriptionMode {
static const SubscriptionMode streaming = StreamingMode();
static const SubscriptionMode polling = PollingMode();

static SubscriptionMode pollingWithInterval(Duration interval) {
return PollingMode(pollInterval: interval);
}

const SubscriptionMode();
}

class PollingMode extends SubscriptionMode {
final Duration pollInterval;

const PollingMode({
this.pollInterval = const Duration(milliseconds: 500),
}) : super();
}

class StreamingMode extends SubscriptionMode {
const StreamingMode() : super();
}

/// The main client class that provides access to simple API model usage,
/// as well as access to the [queue] and [storage] APIs.
Expand All @@ -12,135 +36,146 @@ import './storage.dart';
/// ```dart
/// import 'package:fal_client/client.dart';
///
/// final fal = FalClient.withCredentials("fal_key_id:fal_key_secret");
/// final fal = FalClient.withCredentials("FAL_KEY");
///
/// void main() async {
/// // check https://fal.ai/models for the available models
/// final result = await fal.subscribe('text-to-image', input: {
/// final output = await fal.subscribe('fal-ai/flux/dev', input: {
/// 'prompt': 'a cute shih-tzu puppy',
/// 'model_name': 'stabilityai/stable-diffusion-xl-base-1.0',
/// });
/// print(result);
/// print(output.data);
/// print(output.requestId);
/// }
/// ```
abstract class Client {
abstract class FalClient {
/// The queue client with specific methods to interact with the queue API.
///
/// **Note:** that the [subscribe] method is a convenience method that uses the
/// [queue] client to submit a request and poll for the result.
Queue get queue;
QueueClient get queue;

/// The storage client with specific methods to interact with the storage API.
///
/// **Note:** both [run] and [subscribe] auto-upload files using the [storage]
/// when an [XFile] is passed as an input property value.
Storage get storage;
StorageClient get storage;

/// Sends a request to the given [id], an optional [path]. This method
/// Sends a request to the given [endpointId]. This method
/// is a direct request to the model API and it waits for the processing
/// to complete before returning the result.
///
/// This is useful for short running requests, but it's not recommended for
/// long running requests, for those see [submit].
Future<Map<String, dynamic>> run(
String id, {
Future<FalOutput> run(
String endpointId, {
String method = 'post',
String path = '',
Map<String, dynamic>? input,
});

/// Submits a request to the given [id], an optional [path]. This method
/// Submits a request to the given [endpointId]. This method
/// uses the [queue] API to submit the request and poll for the result.
///
/// This is useful for long running requests, and it's the preffered way
/// to interact with the model APIs.
Future<Map<String, dynamic>> subscribe(
String id, {
String path = '',
Map<String, dynamic>? input,
int pollInterval = 3000,
bool logs = false,
});
///
/// The [webhookUrl] is the URL where the server will send the result once
/// the request is completed. This is particularly useful when you want to
/// receive the result in a different server and for long running requests.
Future<FalOutput> subscribe(String endpointId,
{Map<String, dynamic>? input,
SubscriptionMode mode,
Duration timeout = const Duration(minutes: 5),
bool logs = false,
String? webhookUrl,
Function(String)? onEnqueue,
Function(QueueStatus)? onQueueUpdate});

factory FalClient.withProxy(String proxyUrl) {
return FalClientImpl(config: Config(proxyUrl: proxyUrl));
}

factory FalClient.withCredentials(String credentials) {
return FalClientImpl(config: Config(credentials: credentials));
}
}

/// The default implementation of the [Client] contract.
class FalClient implements Client {
class FalClientImpl implements FalClient {
final Config config;

@override
final Queue queue;
final QueueClient queue;

@override
final Storage storage;
final StorageClient storage;

FalClient({
FalClientImpl({
required this.config,
}) : queue = QueueClient(config: config),
storage = StorageClient(config: config);

factory FalClient.withProxy(String proxyUrl) {
return FalClient(config: Config(proxyUrl: proxyUrl));
}

factory FalClient.withCredentials(String credentials) {
return FalClient(config: Config(credentials: credentials));
}
}) : queue = QueueClientImpl(config: config),
storage = StorageClientImpl(config: config);

@override
Future<Map<String, dynamic>> run(
String id, {
Future<FalOutput> run(
String endpointId, {
String method = 'post',
String path = '',
Map<String, dynamic>? input,
}) async {
final transformedInput =
input != null ? await storage.transformInput(input) : null;
return await sendRequest(
id,
final response = await sendRequest(
endpointId,
config: config,
method: method,
input: transformedInput,
);
return convertResponseToOutput(response);
}

@override
Future<Map<String, dynamic>> subscribe(String id,
{String path = '',
Map<String, dynamic>? input,
int pollInterval = 3000, // 3 seconds
int timeout = 300000, // 5 minutes
Future<FalOutput> subscribe(String endpointId,
{Map<String, dynamic>? input,
SubscriptionMode mode = SubscriptionMode.polling,
Duration timeout = const Duration(minutes: 5),
bool logs = false,
String? webhookUrl,
Function(String)? onEnqueue,
Function(QueueStatus)? onQueueUpdate}) async {
final transformedInput =
input != null ? await storage.transformInput(input) : null;
final enqueued =
await queue.submit(id, input: transformedInput, path: path);
final enqueued = await queue.submit(endpointId,
input: transformedInput, webhookUrl: webhookUrl);
if (onEnqueue != null) {
onEnqueue(enqueued.requestId);
}
final requestId = enqueued.requestId;

if (onEnqueue != null) {
onEnqueue(requestId);
}

return _pollForResult(
id,
requestId: requestId,
logs: logs,
pollInterval: pollInterval,
timeout: timeout,
onQueueUpdate: onQueueUpdate,
);
if (mode is PollingMode) {
return _pollForResult(
endpointId,
requestId: requestId,
logs: logs,
pollInterval: mode.pollInterval,
timeout: timeout,
onQueueUpdate: onQueueUpdate,
);
}

throw UnimplementedError('Streaming mode is not yet implemented.');
}

Future<Map<String, dynamic>> _pollForResult(
String id, {
Future<FalOutput> _pollForResult(
String endpointId, {
required String requestId,
required bool logs,
required int pollInterval,
required int timeout,
required Duration pollInterval,
required Duration timeout,
Function(QueueStatus)? onQueueUpdate,
}) async {
final expiryTime = DateTime.now().add(Duration(milliseconds: timeout));
final expiryTime = DateTime.now().add(timeout);

while (true) {
if (DateTime.now().isAfter(expiryTime)) {
Expand All @@ -149,16 +184,16 @@ class FalClient implements Client {
status: 408);
}
final queueStatus =
await queue.status(id, requestId: requestId, logs: logs);
await queue.status(endpointId, requestId: requestId, logs: logs);

if (onQueueUpdate != null) {
onQueueUpdate(queueStatus);
}

if (queueStatus is CompletedStatus) {
return await queue.result(id, requestId: requestId);
return await queue.result(endpointId, requestId: requestId);
}
await Future.delayed(Duration(milliseconds: pollInterval));
await Future.delayed(pollInterval);
}
}
}
Loading

0 comments on commit 79132b9

Please sign in to comment.