diff --git a/burr/core/action.py b/burr/core/action.py index cb9a7ac5..8ada705f 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -70,17 +70,40 @@ def run(self, state: State, **run_kwargs) -> dict: pass @property - def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: + def inputs( + self, + ) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]: """Represents inputs that are used for this to run. These correspond to the ``**run_kwargs`` in `run` above. Note that this has two possible return values: - 1. A list of strings -- these are the keys that are required to run the function - 2. A tuple of two lists of strings -- the first list is the required keys, the second is the optional keys + 1. A list of strings/dict of string -> ypes -- these are the keys that are required to run the function + 2. A tuple of two lists of strings/dict of strings -> types -- the first list is the required keys, the second is the optional keys :return: Either a list of strings (required inputs) or a tuple of two lists of strings (required and optional inputs) """ - return [] + return [], [] + + @property + def input_schema(self) -> Tuple[dict[str, Type[Type]], dict[str, Type[Type]]]: + """Returns the input schema for the function. + The input schema is a type that can be used to validate the input to the function. + + Note that this is separate from inputs() for backwards compatibility -- + inputs() can return a schema *or* a tuple of required and optional inputs. + + :return: Tuple of required inputs and optional inputs with attached types + """ + inputs = self.inputs + if len(inputs) == 1: + inputs = (inputs[0], {}) + out = [] + for input_spec in inputs: + if isinstance(input_spec, list): + out.append({key: Any for key in input_spec}) + else: + out.append(input_spec) + return tuple(out) @property def optional_and_required_inputs(self) -> tuple[set[str], set[str]]: @@ -213,11 +236,6 @@ def get_source(self) -> str: to display a different source""" return inspect.getsource(self.__class__) - def input_schema(self) -> Any: - """Returns the input schema for the action. - The input schema is a type that can be used to validate the input to the action""" - return None - def __repr__(self): read_repr = ", ".join(self.reads) if self.reads else "{}" write_repr = ", ".join(self.writes) if self.writes else "{}" @@ -534,7 +552,9 @@ def is_async(self) -> bool: # the following exist to share implementation between FunctionBasedStreamingAction and FunctionBasedAction # TODO -- think through the class hierarchy to simplify, for now this is OK -def derive_inputs_from_fn(bound_params: dict, fn: Callable) -> tuple[list[str], list[str]]: +def derive_inputs_from_fn( + bound_params: dict, fn: Callable +) -> tuple[dict[str, Type[Type]], dict[str, Type[Type]]]: """Derives inputs from the function, given the bound parameters. This assumes that the function has inputs named `state`, as well as any number of other kwarg-boundable parameters. @@ -543,20 +563,29 @@ def derive_inputs_from_fn(bound_params: dict, fn: Callable) -> tuple[list[str], :return: Required and optional inputs """ sig = inspect.signature(fn) - required_inputs, optional_inputs = [], [] + required_inputs, optional_inputs = {}, {} for param_name, param in sig.parameters.items(): if param_name != "state" and param_name not in bound_params: if param.default is inspect.Parameter.empty: # has no default means its required - required_inputs.append(param_name) + required_inputs[param_name] = ( + param.annotation if param.annotation != inspect.Parameter.empty else Any + ) else: # has a default means its optional - optional_inputs.append(param_name) + optional_inputs[param_name] = ( + param.annotation if param.annotation != inspect.Parameter.empty else Any + ) return required_inputs, optional_inputs FunctionBasedActionType = Union["FunctionBasedAction", "FunctionBasedStreamingAction"] +InputSpec = Tuple[Union[list[str], Dict[str, Type[Type]], Union[list[str], Dict[str, Type[Type]]]]] + + +OptionalInputType = Optional[InputSpec] + class FunctionBasedAction(SingleStepAction): ACTION_FUNCTION = "action_function" @@ -567,7 +596,9 @@ def __init__( reads: List[str], writes: List[str], bound_params: Optional[dict] = None, - input_spec: Optional[tuple[list[str], list[str]]] = None, + input_spec: Optional[ + Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]] + ] = None, originating_fn: Optional[Callable] = None, schema: ActionSchema = DEFAULT_SCHEMA, ): @@ -589,11 +620,9 @@ def __init__( self._inputs = ( derive_inputs_from_fn(self._bound_params, self._fn) if input_spec is None - else ( - [item for item in input_spec[0] if item not in self._bound_params], - [item for item in input_spec[1] if item not in self._bound_params], - ) + else input_spec ) + print(self._inputs, self._name) self._schema = schema @property @@ -609,7 +638,9 @@ def writes(self) -> list[str]: return self._writes @property - def inputs(self) -> tuple[list[str], list[str]]: + def inputs( + self, + ) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]: return self._inputs @property @@ -630,7 +661,7 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": self._reads, self._writes, {**self._bound_params, **kwargs}, - input_spec=self._inputs, + input_spec=self.input_schema, originating_fn=self._originating_fn, schema=self._schema, ) @@ -1044,7 +1075,9 @@ def __init__( reads: List[str], writes: List[str], bound_params: Optional[dict] = None, - input_spec: Optional[tuple[list[str], list[str]]] = None, + input_spec: Optional[ + Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]] + ] = None, originating_fn: Optional[Callable] = None, schema: ActionSchema = DEFAULT_SCHEMA, ): @@ -1111,13 +1144,15 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction": self._reads, self._writes, {**self._bound_params, **kwargs}, - input_spec=self._inputs, + input_spec=self.input_schema, originating_fn=self._originating_fn, schema=self._schema, ) @property - def inputs(self) -> tuple[list[str], list[str]]: + def inputs( + self, + ) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]: return self._inputs @property diff --git a/burr/integrations/fastapi.py b/burr/integrations/fastapi.py new file mode 100644 index 00000000..f9ba3530 --- /dev/null +++ b/burr/integrations/fastapi.py @@ -0,0 +1,320 @@ +import dataclasses +import functools +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar + +import pydantic +from fastapi import APIRouter, Request +from pydantic import BaseModel + +from burr.core.application import Application, ApplicationBuilder +from burr.core.graph import Graph +from burr.core.persistence import BaseStatePersister, SQLLitePersister +from burr.integrations.pydantic import PydanticTypingSystem, model_to_dict + +import examples.fastapi.application as email_assistant_application + +# current_directory = os.path.dirname(os.path.abspath(__file__)) + +# # Remove the current directory from sys.path +# if current_directory in sys.path: +# sys.path.remove(current_directory) + +# # Import the Pydantic library +# import pydantic +# import fastapi + +# # Add the current directory back to sys.path +# sys.path.insert(0, current_directory) + + +AppStateT = TypeVar("AppStateT", bound=pydantic.BaseModel) + + +class BurrFastAPIConfig(BaseModel): + terminating_actions: List[str] # list of actions after which to terminate + requires_input: List[str] + fixed_inputs: Optional[dict] = None # Fixed inputs to use for all requests + hide_state_fields: Optional[List[str]] = None # Fields to hide from the state + require_partition_keys: bool = False # Require partition keys for all requests + graph: Graph # partially built builder -- no uids or other keys yet + persister: BaseStatePersister # persister so we can handle storage + state_type: Type[pydantic.BaseModel] # state type + default_model: Optional[pydantic.BaseModel] = None + entrypoint: str + + class Config: + arbitrary_types_allowed = True + + # TODO -- expose tracking + hooks + + # This must be partially built (no uids or anything) + + +@functools.lru_cache(maxsize=1000) # TODO -- determine how to do this in a cleaner way +def _get_or_create( + application_id: str, partition_key: Optional[str], config: BurrFastAPIConfig +) -> Application: + return ( + ApplicationBuilder() + .with_graph(config.graph) + .with_state_persister(config.persister) + .with_identifiers(app_id=application_id, partition_key=partition_key) + .with_typing(PydanticTypingSystem(config.state_type)) + .initialize_from( + initializer=config.persister, + resume_at_next_action=True, + default_state=model_to_dict(config.default_model) if config.default_model else {}, + default_entrypoint=config.entrypoint, + ) + .build() + ) + + +async def _run_through( + app_id: Optional[str], partition_key: str, inputs: Dict[str, Any], config: BurrFastAPIConfig +): + """This advances the state machine, moving through to the next 'halting' point""" + app = _get_or_create(app_id, partition_key) + await app.arun( # Using this as a side-effect, we'll just get the state aft + halt_before=config.requires_input, # TODO -- ensure that it's not None + halt_after=config.terminating_actions, + inputs=inputs, + ) + return app.state.data, app.get_next_action() + + +@dataclasses.dataclass +class Endpoint: + path: List[str] # "/do/something/{application_id}/{partition_key}" + method: Literal["GET", "POST"] + body_type: Optional[Type[pydantic.BaseModel]] + response_type: Type[pydantic.BaseModel] + template: Literal["input", "get_or_create"] + config: BurrFastAPIConfig + internal_version: int = 0 + + def get_endpoint_handler(self) -> Callable: + """Gives the endpoint handler for this endpoint to be registered by FastAPI. + Returns a function that FastAPI can parse. + + :return: _description_ + """ + PartitionKeyType = str if self.config.require_partition_keys else Optional[str] + + async def get_or_create_handler( + request: Request, application_id: str, partition_key: PartitionKeyType + ): + app = _get_or_create(application_id, partition_key, self.config) + next_action = app.get_next_action() + return self.response_type( + state=app.state.data, + next_action=next_action.name if next_action is not None else None, + app_id=app.uid, + ) + + async def input_handler( + request: Request, application_id, partition_key: PartitionKeyType, body: self.body_type + ): + # TODO -- implement me! + state_output, next_action = await _run_through( + application_id, partition_key, body.dict(), self.config + ) + # TODO -- ensure this is of the same base-class, we're kind of hardcoding it + return self.response_type( + state=state_output, + next_action=next_action.name if next_action else None, + app_id=application_id, + ) + + if self.template == "input": + return input_handler + elif self.template == "get_or_create": + return get_or_create_handler + + +def _create_input_endpoint(action_name: str, config: BurrFastAPIConfig) -> Endpoint: + """Creates an endpoint for user-required data + + :param burr_app: Application to create endpoint for + :param action_name: Name of the action that that endpoint will *start* at + :param fixed_inputs: Fixed inputs -- these can be used by the endpoint + :return: Endpoint object that will be used to generate a FastAPI app + """ + action = config.graph.get_action(action_name) + if action is None: + raise ValueError(f"Action {action_name} not found in graph") + # Schema for action + required_inputs, optional_inputs = action.input_schema # each Dict[str, type] + # TODO -- create a pydantic model dynamically that has: + # 1. the required inputs as the field "inputs", minus any that are in the variable "fixed_inputs" (hidden) + # 2. the optional inputs in the field "inputs", with default to null + optional, minus any that are in the field "fixed_inputs" (hidden) + filtered_required = { + k: (v, ...) for k, v in required_inputs.items() if k not in (config.fixed_inputs or {}) + } + filtered_optional = { + k: (Optional[v], None) for k, v in optional_inputs.items() if (config.fixed_inputs or {}) + } + + # Combine filtered required and optional fields + inputs_fields = {**filtered_required, **filtered_optional} + + import pprint + + pprint.pprint(inputs_fields) + # Dynamically create the Inputs model + InputsModel = pydantic.create_model("InputsModel", **inputs_fields) + ResponseModel = pydantic.create_model( + "ResponseModel", + state=(config.state_type, ...), + next_action=(Optional[str], ...), + app_id=(str, ...), + ) + return Endpoint( + path=[f"/{action_name}/{{application_id}}/{{partition_key}}"], + method="POST", + body_type=InputsModel, + response_type=ResponseModel, + template="input", + config=config, + ) + + +def _create_get_or_create_endpoint(name: str, config: BurrFastAPIConfig) -> Endpoint: + return Endpoint( + path=[f"/{name}/{{application_id}}/{{partition_key}}"], + method="POST", + body_type=None, + response_type=config.state_type, + template="get_or_create", + config=config, + ) + + +def _gather_endpoints(config: BurrFastAPIConfig) -> List[Endpoint]: + actions_with_input = set(config.requires_input or []) + entrypoint = config.entrypoint + # terminating_actions = set(config.terminating_actions) + + endpoints = [] + for action in config.graph.actions: + if action.name in actions_with_input: + endpoints.append(_create_input_endpoint(action.name, config)) + if action.name == entrypoint: + endpoints.append(_create_input_endpoint(entrypoint, config)) + + endpoints.append(_create_get_or_create_endpoint("get_or_create", config)) + return endpoints + + +def _register_endpoint(router: APIRouter, endpoint: Endpoint): + if endpoint.method == "POST": + router.post( + "/".join(endpoint.path), + response_model=endpoint.response_type, + )(endpoint.get_endpoint_handler()) + elif endpoint.method == "GET": + router.get( + "/".join(endpoint.path), + response_model=endpoint.response_type, + )(endpoint.get_endpoint_handler()) + # TODO -- handle other types + + +# def _validate_and_extract_app_type( +# burr_app: Application[pydantic.BaseModel], +# ) -> Type[pydantic.BaseModel]: +# typing_system = burr_app.state.typing_system +# if not isinstance(typing_system, PydanticTypingSystem): +# raise ValueError( +# "Burr FastAPI requires a PydanticTypingSystem. Use with_typing(PydanticTypingSystem(MyStateModel(...))) to specify" +# ) +# return typing_system.state_type() + + +def _validate_terminating_actions(graph: Graph, terminating_actions: List[str]): + missing_actions = set(terminating_actions) - {action.name for action in graph.actions} + if missing_actions: + raise ValueError(f"Terminating actions {missing_actions} not found in graph") + + +def _validate_inputs_types(input_types: dict, action_name: str): + inputs_with_any_type = set() + for key, value in input_types.items(): + if value is Any: + inputs_with_any_type.add(key) + if inputs_with_any_type: + raise ValueError( + f"Action {action_name} has inputs with Any type: {inputs_with_any_type}." + "This means that they are not specified. Please specify by assigning" + "parameter types in the action definition" + ) + + +def _validate_require_input_actions(config: BurrFastAPIConfig): + requires_input_actions = set(config.requires_input or []) + missing_actions = requires_input_actions - {action.name for action in config.graph.actions} + if missing_actions: + raise ValueError(f"Actions {missing_actions} not found in graph") + burr_actions = {action.name: action for action in config.graph.actions} + for action_name in requires_input_actions: + action = burr_actions[action_name] + required_inputs, optional_inputs = action.input_schema + _validate_inputs_types(required_inputs, action_name) + _validate_inputs_types(optional_inputs, action_name) + + +def _validate_state_hide_fields(config: BurrFastAPIConfig, state_type: Type[pydantic.BaseModel]): + hidden_fields = set(config.hide_state_fields or []) + model_fields = set(state_type.model_fields.keys()) + missing_fields = hidden_fields - model_fields + if missing_fields: + raise ValueError( + f"Fields {missing_fields} not found in state model, ", + "but specified in hide_fields. Please remove from hide_fields or add to state model", + ) + + +def _validate_config(config: BurrFastAPIConfig): + app_type = config.state_type + _validate_terminating_actions(config.graph, config.terminating_actions) + _validate_require_input_actions(config) + _validate_state_hide_fields(config, app_type) + + +def expose(router: APIRouter, config: BurrFastAPIConfig): + """Exposes a burr app as a fastAPI app + + :param router: _description_ + :param burr_app: _description_ + :param config: _description_ + """ + _validate_config(config) + endpoints = _gather_endpoints(config) + for endpoint in endpoints: + _register_endpoint(router, endpoint) + + +if __name__ == "__main__": + import uvicorn + from fastapi import FastAPI + + # tracker = LocalTrackingClient(project="test_fastapi_autogen") + config = BurrFastAPIConfig( + terminating_actions=["final_result"], + requires_input=["clarify_instructions", "process_feedback"], + fixed_inputs={}, + hide_state_fields=[], + require_partition_keys=False, + graph=email_assistant_application.graph, + persister=SQLLitePersister(db_path=".sqllite.db", table_name="test1"), + state_type=email_assistant_application.EmailAssistantState, + default_model=email_assistant_application.EmailAssistantState(), + entrypoint="process_input", + ) + app = FastAPI() + router = APIRouter( + prefix="/email_assistant", + ) + expose(router, config) + app.include_router(router) + uvicorn.run(app, host="localhost", port=7245) diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index fdfe9774..bbdc3760 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -216,6 +216,7 @@ async def async_action_function(state: State, **kwargs) -> State: # TODO -- use the @action decorator directly # TODO -- ensure that the function is the right one -- specifically it probably won't show code in the UI # now + setattr( fn, FunctionBasedAction.ACTION_FUNCTION, diff --git a/burr/tracking/client.py b/burr/tracking/client.py index f9070e18..8c61ddbe 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -133,10 +133,7 @@ def copy(self) -> Self: pass -class LocalTrackingClient( - SyncTrackingClient, - BaseStateLoader, -): +class LocalTrackingClient(SyncTrackingClient, BaseStateLoader): """Tracker to track locally -- goes along with the Burr UI. Writes down the following: #. The whole application + debugging information (e.g. source code) to a file diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index f99c4602..8ddffad5 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -32,6 +32,7 @@ # dynamic importing due to the dashes (which make reading the examples on github easier) email_assistant = importlib.import_module("burr.examples.email-assistant.server") + email_assistant_typed = importlib.import_module("burr.examples.fastapi.server") chatbot = importlib.import_module("burr.examples.multi-modal-chatbot.server") streaming_chatbot = importlib.import_module("burr.examples.streaming-fastapi.server") @@ -249,6 +250,7 @@ async def version() -> dict: # Examples -- todo -- put them behind `if` statements app.include_router(chatbot.router, prefix="/api/v0/chatbot") app.include_router(email_assistant.router, prefix="/api/v0/email_assistant") +app.include_router(email_assistant_typed.router, prefix="/api/v0/email_assistant_typed") app.include_router(streaming_chatbot.router, prefix="/api/v0/streaming_chatbot") if SERVE_STATIC: diff --git a/telemetry/ui/src/api/index.ts b/telemetry/ui/src/api/index.ts index 35b03467..e8967da0 100644 --- a/telemetry/ui/src/api/index.ts +++ b/telemetry/ui/src/api/index.ts @@ -12,17 +12,22 @@ export type { ApplicationLogs } from './models/ApplicationLogs'; export type { ApplicationModel } from './models/ApplicationModel'; export type { ApplicationPage } from './models/ApplicationPage'; export type { ApplicationSummary } from './models/ApplicationSummary'; +export type { AppResponse } from './models/AppResponse'; export type { AttributeModel } from './models/AttributeModel'; export type { BackendSpec } from './models/BackendSpec'; export type { BeginEntryModel } from './models/BeginEntryModel'; export type { BeginSpanModel } from './models/BeginSpanModel'; +export { burr__examples__email_assistant__server__EmailAssistantState } from './models/burr__examples__email_assistant__server__EmailAssistantState'; export { ChatItem } from './models/ChatItem'; export { ChildApplicationModel } from './models/ChildApplicationModel'; +export type { ClarificationAnswers } from './models/ClarificationAnswers'; +export type { ClarificationQuestions } from './models/ClarificationQuestions'; export type { DraftInit } from './models/DraftInit'; -export { EmailAssistantState } from './models/EmailAssistantState'; +export type { Email } from './models/Email'; export type { EndEntryModel } from './models/EndEntryModel'; export type { EndSpanModel } from './models/EndSpanModel'; export type { EndStreamModel } from './models/EndStreamModel'; +export type { examples__fastapi__application__EmailAssistantState } from './models/examples__fastapi__application__EmailAssistantState'; export type { Feedback } from './models/Feedback'; export type { FirstItemStreamModel } from './models/FirstItemStreamModel'; export type { HTTPValidationError } from './models/HTTPValidationError'; diff --git a/telemetry/ui/src/api/models/AppResponse.ts b/telemetry/ui/src/api/models/AppResponse.ts new file mode 100644 index 00000000..e526c9f6 --- /dev/null +++ b/telemetry/ui/src/api/models/AppResponse.ts @@ -0,0 +1,10 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { examples__fastapi__application__EmailAssistantState } from './examples__fastapi__application__EmailAssistantState'; +export type AppResponse = { + app_id: string; + next_step: 'process_input' | 'clarify_instructions' | 'process_feedback' | null; + state: examples__fastapi__application__EmailAssistantState; +}; diff --git a/telemetry/ui/src/api/models/ClarificationAnswers.ts b/telemetry/ui/src/api/models/ClarificationAnswers.ts new file mode 100644 index 00000000..98bdd15a --- /dev/null +++ b/telemetry/ui/src/api/models/ClarificationAnswers.ts @@ -0,0 +1,7 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type ClarificationAnswers = { + answers: Array; +}; diff --git a/telemetry/ui/src/api/models/ClarificationQuestions.ts b/telemetry/ui/src/api/models/ClarificationQuestions.ts new file mode 100644 index 00000000..8ce3afbe --- /dev/null +++ b/telemetry/ui/src/api/models/ClarificationQuestions.ts @@ -0,0 +1,7 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type ClarificationQuestions = { + question: Array; +}; diff --git a/telemetry/ui/src/api/models/Email.ts b/telemetry/ui/src/api/models/Email.ts new file mode 100644 index 00000000..7e050b8c --- /dev/null +++ b/telemetry/ui/src/api/models/Email.ts @@ -0,0 +1,8 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type Email = { + subject: string; + contents: string; +}; diff --git a/telemetry/ui/src/api/models/EmailAssistantState.ts b/telemetry/ui/src/api/models/burr__examples__email_assistant__server__EmailAssistantState.ts similarity index 69% rename from telemetry/ui/src/api/models/EmailAssistantState.ts rename to telemetry/ui/src/api/models/burr__examples__email_assistant__server__EmailAssistantState.ts index 4a2ed93b..233ff84a 100644 --- a/telemetry/ui/src/api/models/EmailAssistantState.ts +++ b/telemetry/ui/src/api/models/burr__examples__email_assistant__server__EmailAssistantState.ts @@ -2,7 +2,7 @@ /* istanbul ignore file */ /* tslint:disable */ /* eslint-disable */ -export type EmailAssistantState = { +export type burr__examples__email_assistant__server__EmailAssistantState = { app_id: string; email_to_respond: string | null; response_instructions: string | null; @@ -11,9 +11,9 @@ export type EmailAssistantState = { drafts: Array; feedback_history: Array; final_draft: string | null; - next_step: EmailAssistantState.next_step; + next_step: burr__examples__email_assistant__server__EmailAssistantState.next_step; }; -export namespace EmailAssistantState { +export namespace burr__examples__email_assistant__server__EmailAssistantState { export enum next_step { PROCESS_INPUT = 'process_input', CLARIFY_INSTRUCTIONS = 'clarify_instructions', diff --git a/telemetry/ui/src/api/models/examples__fastapi__application__EmailAssistantState.ts b/telemetry/ui/src/api/models/examples__fastapi__application__EmailAssistantState.ts new file mode 100644 index 00000000..518ad1d9 --- /dev/null +++ b/telemetry/ui/src/api/models/examples__fastapi__application__EmailAssistantState.ts @@ -0,0 +1,18 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { ClarificationAnswers } from './ClarificationAnswers'; +import type { ClarificationQuestions } from './ClarificationQuestions'; +import type { Email } from './Email'; +export type examples__fastapi__application__EmailAssistantState = { + email_to_respond?: string | null; + response_instructions?: string | null; + questions?: ClarificationQuestions | null; + answers?: ClarificationAnswers | null; + draft_history?: Array; + current_draft?: Email | null; + feedback_history?: Array; + feedback?: string | null; + final_draft?: string | null; +}; diff --git a/telemetry/ui/src/api/services/DefaultService.ts b/telemetry/ui/src/api/services/DefaultService.ts index 2d4ff459..6ee309c8 100644 --- a/telemetry/ui/src/api/services/DefaultService.ts +++ b/telemetry/ui/src/api/services/DefaultService.ts @@ -4,10 +4,11 @@ /* eslint-disable */ import type { ApplicationLogs } from '../models/ApplicationLogs'; import type { ApplicationPage } from '../models/ApplicationPage'; +import type { AppResponse } from '../models/AppResponse'; import type { BackendSpec } from '../models/BackendSpec'; +import type { burr__examples__email_assistant__server__EmailAssistantState } from '../models/burr__examples__email_assistant__server__EmailAssistantState'; import type { ChatItem } from '../models/ChatItem'; import type { DraftInit } from '../models/DraftInit'; -import type { EmailAssistantState } from '../models/EmailAssistantState'; import type { Feedback } from '../models/Feedback'; import type { IndexingJob } from '../models/IndexingJob'; import type { Project } from '../models/Project'; @@ -297,14 +298,14 @@ export class DefaultService { * @param projectId * @param appId * @param requestBody - * @returns EmailAssistantState Successful Response + * @returns burr__examples__email_assistant__server__EmailAssistantState Successful Response * @throws ApiError */ public static initializeDraftApiV0EmailAssistantCreateProjectIdAppIdPost( projectId: string, appId: string, requestBody: DraftInit - ): CancelablePromise { + ): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v0/email_assistant/create/{project_id}/{app_id}', @@ -330,14 +331,14 @@ export class DefaultService { * @param projectId * @param appId * @param requestBody - * @returns EmailAssistantState Successful Response + * @returns burr__examples__email_assistant__server__EmailAssistantState Successful Response * @throws ApiError */ public static answerQuestionsApiV0EmailAssistantAnswerQuestionsProjectIdAppIdPost( projectId: string, appId: string, requestBody: QuestionAnswers - ): CancelablePromise { + ): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v0/email_assistant/answer_questions/{project_id}/{app_id}', @@ -363,14 +364,14 @@ export class DefaultService { * @param projectId * @param appId * @param requestBody - * @returns EmailAssistantState Successful Response + * @returns burr__examples__email_assistant__server__EmailAssistantState Successful Response * @throws ApiError */ public static provideFeedbackApiV0EmailAssistantProvideFeedbackProjectIdAppIdPost( projectId: string, appId: string, requestBody: Feedback - ): CancelablePromise { + ): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v0/email_assistant/provide_feedback/{project_id}/{app_id}', @@ -394,13 +395,13 @@ export class DefaultService { * :return: The state of the application * @param projectId * @param appId - * @returns EmailAssistantState Successful Response + * @returns burr__examples__email_assistant__server__EmailAssistantState Successful Response * @throws ApiError */ public static getStateApiV0EmailAssistantStateProjectIdAppIdGet( projectId: string, appId: string - ): CancelablePromise { + ): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v0/email_assistant/state/{project_id}/{app_id}', @@ -427,6 +428,170 @@ export class DefaultService { url: '/api/v0/email_assistant/validate/{project_id}/{app_id}' }); } + /** + * Create New Application + * @param projectId + * @param appId + * @returns string Successful Response + * @throws ApiError + */ + public static createNewApplicationApiV0EmailAssistantTypedCreateNewProjectIdAppIdPost( + projectId: string, + appId: string + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/email_assistant_typed/create_new/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + errors: { + 422: `Validation Error` + } + }); + } + /** + * Initialize Draft + * Endpoint to initialize the draft with the email and instructions + * + * :param project_id: ID of the project (used by telemetry tracking/storage) + * :param app_id: ID of the application (used to reference the app) + * :param draft_data: Data to initialize the draft + * :return: The state of the application after initialization + * @param projectId + * @param appId + * @param requestBody + * @returns AppResponse Successful Response + * @throws ApiError + */ + public static initializeDraftApiV0EmailAssistantTypedCreateProjectIdAppIdPost( + projectId: string, + appId: string, + requestBody: DraftInit + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/email_assistant_typed/create/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error` + } + }); + } + /** + * Answer Questions + * Endpoint to answer questions the LLM provides + * + * :param project_id: ID of the project (used by telemetry tracking/storage) + * :param app_id: ID of the application (used to reference the app) + * :param question_answers: Answers to the questions + * :return: The state of the application after answering the questions + * @param projectId + * @param appId + * @param requestBody + * @returns AppResponse Successful Response + * @throws ApiError + */ + public static answerQuestionsApiV0EmailAssistantTypedAnswerQuestionsProjectIdAppIdPost( + projectId: string, + appId: string, + requestBody: QuestionAnswers + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/email_assistant_typed/answer_questions/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error` + } + }); + } + /** + * Provide Feedback + * Endpoint to provide feedback to the LLM + * + * :param project_id: ID of the project (used by telemetry tracking/storage) + * :param app_id: ID of the application (used to reference the app) + * :param feedback: Feedback to provide to the LLM + * :return: The state of the application after providing feedback + * @param projectId + * @param appId + * @param requestBody + * @returns AppResponse Successful Response + * @throws ApiError + */ + public static provideFeedbackApiV0EmailAssistantTypedProvideFeedbackProjectIdAppIdPost( + projectId: string, + appId: string, + requestBody: Feedback + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/email_assistant_typed/provide_feedback/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error` + } + }); + } + /** + * Get State + * Get the current state of the application + * + * :param project_id: ID of the project (used by telemetry tracking/storage) + * :param app_id: ID of the application (used to reference the app) + * :return: The state of the application + * @param projectId + * @param appId + * @returns AppResponse Successful Response + * @throws ApiError + */ + public static getStateApiV0EmailAssistantTypedStateProjectIdAppIdGet( + projectId: string, + appId: string + ): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/email_assistant_typed/state/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + errors: { + 422: `Validation Error` + } + }); + } + /** + * Validate Environment + * Validate the environment + * @returns any Successful Response + * @throws ApiError + */ + public static validateEnvironmentApiV0EmailAssistantTypedValidateProjectIdAppIdGet(): CancelablePromise< + string | null + > { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/email_assistant_typed/validate/{project_id}/{app_id}' + }); + } /** * Chat Response * Chat response endpoint. User passes in a prompt and the system returns the diff --git a/telemetry/ui/src/examples/EmailAssistant.tsx b/telemetry/ui/src/examples/EmailAssistant.tsx index 42f80fce..4029cda3 100644 --- a/telemetry/ui/src/examples/EmailAssistant.tsx +++ b/telemetry/ui/src/examples/EmailAssistant.tsx @@ -4,9 +4,11 @@ import { ApplicationSummary, DefaultService, DraftInit, - EmailAssistantState, + // examples__fastapi__application__EmailAssistantState as EmailAssistantState, + AppResponse, Feedback, - QuestionAnswers + QuestionAnswers, + Email } from '../api'; import { useEffect, useState } from 'react'; import { useMutation, useQuery } from 'react-query'; @@ -77,19 +79,20 @@ export const InitialDraftView = (props: { }; export const SubmitAnswersView = (props: { - state: EmailAssistantState; + state: AppResponse; submitAnswers: (questions: QuestionAnswers) => void; questions: string[] | null; answers: string[] | null; isLoading: boolean; }) => { - const questions = props.state.questions || []; + const questions = props.state.state.questions?.question || []; const [answers, setAnswers] = useState(props.answers || questions.map(() => '')); const editMode = props.isLoading || props.answers === null; + const { state } = props.state; return (
- {(props.state.questions || []).map((question, index) => { + {(state.questions?.question || []).map((question, index) => { return ( @@ -121,9 +124,9 @@ export const SubmitAnswersView = (props: { }; export const SubmitFeedbackView = (props: { - state: EmailAssistantState; + state: AppResponse; submitFeedback: (feedbacks: Feedback) => void; - drafts: string[] | null; + drafts: Email[] | null; feedbacks: string[] | null; isLoading: boolean; }) => { @@ -141,7 +144,7 @@ export const SubmitFeedbackView = (props: { <>
-                  {draft}
+                  {draft.subject + '\n' + draft.contents || ''}
                 

@@ -185,18 +188,21 @@ export const SubmitFeedbackView = (props: { export const EmailAssistant = (props: { projectId: string; appId: string | undefined }) => { // starts off as null - const [emailAssistantState, setEmailAssistantState] = useState(null); + const [appResponse, setAppResponse] = useState(null); const { data: validationData, isLoading: isValidationLoading } = useQuery( ['valid', props.projectId, props.appId], - DefaultService.validateEnvironmentApiV0EmailAssistantValidateProjectIdAppIdGet + DefaultService.validateEnvironmentApiV0EmailAssistantTypedValidateProjectIdAppIdGet ); useEffect(() => { if (props.appId !== undefined) { // TODO -- handle errors - DefaultService.getStateApiV0EmailAssistantStateProjectIdAppIdGet(props.projectId, props.appId) + DefaultService.getStateApiV0EmailAssistantTypedStateProjectIdAppIdGet( + props.projectId, + props.appId + ) .then((data) => { - setEmailAssistantState(data); // we want to initialize the chat history + setAppResponse(data); // we want to initialize the chat history }) .catch((e) => { // eslint-disable-next-line @@ -209,7 +215,7 @@ export const EmailAssistant = (props: { projectId: string; appId: string | undef // TODO -- handle errors ['emailAssistant', props.projectId, props.appId], () => - DefaultService.getStateApiV0EmailAssistantStateProjectIdAppIdGet( + DefaultService.getStateApiV0EmailAssistantTypedStateProjectIdAppIdGet( props.projectId, props.appId || '' // TODO -- find a cleaner way of doing a skip-token like thing here // This is skipped if the appId is undefined so this is just to make the type-checker happy @@ -217,48 +223,48 @@ export const EmailAssistant = (props: { projectId: string; appId: string | undef { enabled: props.appId !== undefined, onSuccess: (data) => { - setEmailAssistantState(data); // when its succesful we want to set the displayed chat history + setAppResponse(data); // when its succesful we want to set the displayed chat history } } ); const submitInitialMutation = useMutation( (draftData: DraftInit) => - DefaultService.initializeDraftApiV0EmailAssistantCreateProjectIdAppIdPost( + DefaultService.initializeDraftApiV0EmailAssistantTypedCreateProjectIdAppIdPost( props.projectId, props.appId || 'create_new', draftData ), { onSuccess: (data) => { - setEmailAssistantState(data); + setAppResponse(data); } } ); const submitAnswersMutation = useMutation( (answers: QuestionAnswers) => - DefaultService.answerQuestionsApiV0EmailAssistantAnswerQuestionsProjectIdAppIdPost( + DefaultService.answerQuestionsApiV0EmailAssistantTypedAnswerQuestionsProjectIdAppIdPost( props.projectId, props.appId || '', answers ), { onSuccess: (data) => { - setEmailAssistantState(data); + setAppResponse(data); } } ); const submitFeedbackMutation = useMutation( (feedbacks: Feedback) => - DefaultService.provideFeedbackApiV0EmailAssistantProvideFeedbackProjectIdAppIdPost( + DefaultService.provideFeedbackApiV0EmailAssistantTypedProvideFeedbackProjectIdAppIdPost( props.projectId, props.appId || '', feedbacks ), { onSuccess: (data) => { - setEmailAssistantState(data); + setAppResponse(data); } } ); @@ -273,13 +279,12 @@ export const EmailAssistant = (props: { projectId: string; appId: string | undef return ; } const displayValidationError = validationData !== null; - const displayInstructions = emailAssistantState === null && !displayValidationError; - const displayInitialDraft = emailAssistantState !== null; - const displaySubmitAnswers = - displayInitialDraft && emailAssistantState.next_step !== 'process_input'; + const displayInstructions = appResponse === null && !displayValidationError; + const displayInitialDraft = appResponse !== null; + const displaySubmitAnswers = displayInitialDraft && appResponse.next_step !== 'process_input'; const displaySubmitFeedback = - displaySubmitAnswers && emailAssistantState.next_step !== 'clarify_instructions'; - const displayFinalDraft = displaySubmitFeedback && emailAssistantState.next_step === null; + displaySubmitAnswers && appResponse.next_step !== 'clarify_instructions'; + const displayFinalDraft = displaySubmitFeedback && appResponse.next_step === null; return (

{'Learn Burr '}

@@ -303,30 +308,30 @@ export const EmailAssistant = (props: { projectId: string; appId: string | undef submitInitial={(initial) => { submitInitialMutation.mutate(initial); }} - responseInstructions={emailAssistantState?.response_instructions} - emailToRespond={emailAssistantState?.email_to_respond} + responseInstructions={appResponse.state?.response_instructions || null} + emailToRespond={appResponse.state?.email_to_respond || null} /> )}
{displaySubmitAnswers && ( { submitAnswersMutation.mutate(answers); }} - questions={emailAssistantState.questions} - answers={emailAssistantState.answers} + questions={appResponse.state.questions?.question || null} + answers={appResponse.state.answers?.answers || null} isLoading={anyMutationLoading} /> )} {displaySubmitFeedback && ( { submitFeedbackMutation.mutate(feedbacks); }} - drafts={emailAssistantState.drafts} - feedbacks={emailAssistantState.feedback_history} + drafts={appResponse.state.draft_history || null} + feedbacks={appResponse.state.feedback_history || null} isLoading={anyMutationLoading} /> )} @@ -335,7 +340,9 @@ export const EmailAssistant = (props: { projectId: string; appId: string | undef

Final Draft

-              {emailAssistantState.drafts?.[emailAssistantState.drafts.length - 1]}
+              {appResponse.state.current_draft?.subject +
+                '\n' +
+                appResponse.state.current_draft?.contents || ''}