From 6fba65678c2de7a1a58124b95767cf49911ad0c2 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Mon, 30 Sep 2024 15:42:32 -0400 Subject: [PATCH] feat: mo.ui.chat() (#2436) * mo.ai.chat() * progress * example * improvements * ignore imports * ignore again * cr comments * docs * more tests * fixes * docs * more docs * copyright * add variable tempalting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/api/inputs/chat.md | 135 +++++ docs/api/inputs/index.md | 2 + examples/ai/simple_chatbot.py | 76 +++ frontend/src/components/ui/input.tsx | 3 +- frontend/src/plugins/impl/chat/ChatPlugin.tsx | 74 +++ frontend/src/plugins/impl/chat/chat-ui.tsx | 494 ++++++++++++++++++ frontend/src/plugins/impl/chat/types.ts | 26 + frontend/src/plugins/plugins.ts | 2 + marimo/__init__.py | 22 +- marimo/_dependencies/dependencies.py | 1 + marimo/_plugins/ai/__init__.py | 10 + marimo/_plugins/ui/__init__.py | 10 +- marimo/_plugins/ui/_impl/chat/chat.py | 168 ++++++ marimo/_plugins/ui/_impl/chat/convert.py | 62 +++ marimo/_plugins/ui/_impl/chat/models.py | 229 ++++++++ marimo/_plugins/ui/_impl/chat/types.py | 79 +++ marimo/_plugins/ui/_impl/chat/utils.py | 35 ++ marimo/_plugins/validators.py | 1 + .../packages/module_name_to_pypi_name.py | 1 + marimo/_smoke_tests/chat/chatbot.py | 229 ++++++++ tests/_plugins/ui/_impl/chat/test_chat.py | 166 ++++++ .../ui/_impl/chat/test_chat_convert.py | 185 +++++++ .../_plugins/ui/_impl/chat/test_chat_model.py | 23 + 23 files changed, 2019 insertions(+), 14 deletions(-) create mode 100644 docs/api/inputs/chat.md create mode 100644 examples/ai/simple_chatbot.py create mode 100644 frontend/src/plugins/impl/chat/ChatPlugin.tsx create mode 100644 frontend/src/plugins/impl/chat/chat-ui.tsx create mode 100644 frontend/src/plugins/impl/chat/types.ts create mode 100644 marimo/_plugins/ai/__init__.py create mode 100644 marimo/_plugins/ui/_impl/chat/chat.py create mode 100644 marimo/_plugins/ui/_impl/chat/convert.py create mode 100644 marimo/_plugins/ui/_impl/chat/models.py create mode 100644 marimo/_plugins/ui/_impl/chat/types.py create mode 100644 marimo/_plugins/ui/_impl/chat/utils.py create mode 100644 marimo/_smoke_tests/chat/chatbot.py create mode 100644 tests/_plugins/ui/_impl/chat/test_chat.py create mode 100644 tests/_plugins/ui/_impl/chat/test_chat_convert.py create mode 100644 tests/_plugins/ui/_impl/chat/test_chat_model.py diff --git a/docs/api/inputs/chat.md b/docs/api/inputs/chat.md new file mode 100644 index 00000000000..a0fe6f31ae8 --- /dev/null +++ b/docs/api/inputs/chat.md @@ -0,0 +1,135 @@ +# Chat + +```{eval-rst} +.. marimo-embed:: + :size: large + + @app.cell + def __(): + def simple_echo_model(messages, config): + return f"You said: {messages[-1].content}" + + mo.ui.chat( + simple_echo_model, + prompts=["Hello", "How are you?"], + show_configuration_controls=True + ) + return +``` + +The chat UI element provides an interactive chatbot interface for conversations. It can be customized with different models, including built-in AI models or custom functions. + +```{eval-rst} +.. autoclass:: marimo.ui.chat + :members: + + .. 
autoclasstoc:: marimo._plugins.ui._impl.chat.chat.chat +``` + +## Basic Usage + +Here's a simple example using a custom echo model: + +```python +import marimo as mo + +def echo_model(messages, config): + return f"Echo: {messages[-1].content}" + +chat = mo.ui.chat(echo_model, prompts=["Hello", "How are you?"]) +chat +``` + +## Using a Built-in AI Model + +You can use marimo's built-in AI models, such as OpenAI's GPT: + +```python +import marimo as mo + +chat = mo.ui.chat( + mo.ai.openai( + "gpt-4", + system_message="You are a helpful assistant.", + ), + show_configuration_controls=True +) +chat +``` + +## Accessing Chat History + +You can access the chat history using the `value` attribute: + +```python +chat.value +``` + +This returns a list of `ChatMessage` objects, each containing `role` and `content` attributes. + +```{eval-rst} +.. autoclass:: ChatMessage + :members: + + .. autoclasstoc:: marimo._plugins.ui._impl.chat.types.ChatMessage +``` + +## Custom Model with Additional Context + +Here's an example of a custom model that uses additional context: + +```python +import marimo as mo + +def rag_model(messages, config): + question = messages[-1].content + docs = find_relevant_docs(question) + context = "\n".join(docs) + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" + response = query_llm(prompt, config) + return response + +mo.ui.chat(rag_model) +``` + +This example demonstrates how you can implement a Retrieval-Augmented Generation (RAG) model within the chat interface. + +## Built-in Models + +marimo provides several built-in AI models that you can use with the chat UI element. + +```python +import marimo as mo + +mo.ui.chat( + mo.ai.openai( + "gpt-4", + system_message="You are a helpful assistant.", + api_key="sk-...", + ), + show_configuration_controls=True +) + +mo.ui.chat( + mo.ai.anthropic( + "claude-3-5-sonnet-20240602", + system_message="You are a helpful assistant.", + api_key="sk-...", + ), + show_configuration_controls=True +) +``` + +```{eval-rst} +.. autoclass:: marimo.ai.models.openai + :members: + + .. autoclasstoc:: marimo._plugins.ui._impl.chat.models.openai +``` + +```{eval-rst} +.. autoclass:: marimo.ai.models.anthropic + :members: + + .. autoclasstoc:: marimo._plugins.ui._impl.chat.models.anthropic +``` diff --git a/docs/api/inputs/index.md b/docs/api/inputs/index.md index 2a56eb49595..def6e31f004 100644 --- a/docs/api/inputs/index.md +++ b/docs/api/inputs/index.md @@ -8,6 +8,7 @@ array batch button + chat checkbox code_editor dataframe @@ -44,6 +45,7 @@ powerful notebooks and apps. These elements are available in `marimo.ui`. 
marimo.ui.array marimo.ui.batch marimo.ui.button + marimo.ui.chat marimo.ui.checkbox marimo.ui.code_editor marimo.ui.dataframe diff --git a/examples/ai/simple_chatbot.py b/examples/ai/simple_chatbot.py new file mode 100644 index 00000000000..e1987a1656d --- /dev/null +++ b/examples/ai/simple_chatbot.py @@ -0,0 +1,76 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "ell-ai==0.0.12", +# "marimo", +# "openai==1.50.1", +# ] +# /// + +import marimo + +__generated_with = "0.8.20" +app = marimo.App(width="medium") + + +@app.cell(hide_code=True) +def __(mo): + mo.md(r"""# Simple chatbot 🤖""") + return + + +@app.cell(hide_code=True) +def __(mo): + import os + + os_key = os.environ.get("OPENAI_API_KEY") + input_key = mo.ui.text(label="OpenAI API key", kind="password") + input_key if not os_key else None + return input_key, os, os_key + + +@app.cell(hide_code=True) +def __(input_key, mo, os_key): + # Initialize a client + openai_key = os_key or input_key.value + + mo.stop( + not openai_key, + "Please set the OPENAI_API_KEY environment variable or provide it in the input field", + ) + + import ell + import openai + + # Create an openai client + client = openai.Client(api_key=openai_key) + return client, ell, openai, openai_key + + +@app.cell +def __(client, ell, mo): + @ell.simple("gpt-4o-mini-2024-07-18", client=client) + def _my_model(prompt): + """You are an annoying little brother, whatever I say, be sassy with your response""" + return prompt + + + mo.ui.chat( + mo.ai.models.simple(_my_model), + prompts=[ + "Hello", + "How are you?", + "I'm doing great, how about you?", + ], + ) + return + + +@app.cell +def __(): + import marimo as mo + return (mo,) + + +if __name__ == "__main__": + app.run() diff --git a/frontend/src/components/ui/input.tsx b/frontend/src/components/ui/input.tsx index dd30c2b3514..d2671a097ee 100644 --- a/frontend/src/components/ui/input.tsx +++ b/frontend/src/components/ui/input.tsx @@ -12,6 +12,7 @@ import { SearchIcon, XIcon } from "lucide-react"; import { useControllableState } from "@radix-ui/react-use-controllable-state"; export type InputProps = React.InputHTMLAttributes & { + rootClassName?: string; icon?: React.ReactNode; endAdornment?: React.ReactNode; }; @@ -25,7 +26,7 @@ const Input = React.forwardRef( } return ( -
+
{icon && (
{icon} diff --git a/frontend/src/plugins/impl/chat/ChatPlugin.tsx b/frontend/src/plugins/impl/chat/ChatPlugin.tsx new file mode 100644 index 00000000000..a4ba1335fe9 --- /dev/null +++ b/frontend/src/plugins/impl/chat/ChatPlugin.tsx @@ -0,0 +1,74 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import { z } from "zod"; +import { createPlugin } from "@/plugins/core/builder"; +import { rpc } from "@/plugins/core/rpc"; +import { TooltipProvider } from "@/components/ui/tooltip"; +import { Chatbot } from "./chat-ui"; +import type { ChatMessage, SendMessageRequest } from "./types"; +import { Arrays } from "@/utils/arrays"; + +// eslint-disable-next-line @typescript-eslint/consistent-type-definitions +type PluginFunctions = { + get_chat_history: () => Promise<{ messages: ChatMessage[] }>; + send_prompt: (req: SendMessageRequest) => Promise; +}; + +export const ChatPlugin = createPlugin("marimo-chatbot") + .withData( + z.object({ + prompts: z.array(z.string()).default(Arrays.EMPTY), + showConfigurationControls: z.boolean(), + config: z.object({ + maxTokens: z.number().default(100), + temperature: z.number().default(0.5), + topP: z.number().default(1), + topK: z.number().default(40), + frequencyPenalty: z.number().default(0), + presencePenalty: z.number().default(0), + }), + }), + ) + .withFunctions({ + get_chat_history: rpc.input(z.object({})).output( + z.object({ + messages: z.array( + z.object({ + role: z.enum(["system", "user", "assistant"]), + content: z.string(), + }), + ), + }), + ), + send_prompt: rpc + .input( + z.object({ + messages: z.array( + z.object({ + role: z.enum(["system", "user", "assistant"]), + content: z.string(), + }), + ), + config: z.object({ + max_tokens: z.number(), + temperature: z.number(), + top_p: z.number(), + top_k: z.number(), + frequency_penalty: z.number(), + presence_penalty: z.number(), + }), + }), + ) + .output(z.string()), + }) + .renderer((props) => ( + + + + )); diff --git a/frontend/src/plugins/impl/chat/chat-ui.tsx b/frontend/src/plugins/impl/chat/chat-ui.tsx new file mode 100644 index 00000000000..e884605b6d6 --- /dev/null +++ b/frontend/src/plugins/impl/chat/chat-ui.tsx @@ -0,0 +1,494 @@ +/* Copyright 2024 Marimo. All rights reserved. 
*/
+import { Spinner } from "@/components/icons/spinner";
+import { Logger } from "@/utils/Logger";
+import { type Message, useChat } from "ai/react";
+import React, { useEffect, useState } from "react";
+import type { ChatMessage, ChatConfig, SendMessageRequest } from "./types";
+import { ErrorBanner } from "../common/error-banner";
+import { Button } from "@/components/ui/button";
+import {
+  BotMessageSquareIcon,
+  ClipboardIcon,
+  HelpCircleIcon,
+  SendIcon,
+  SettingsIcon,
+  Trash2Icon,
+} from "lucide-react";
+import { cn } from "@/utils/cn";
+import { toast } from "@/components/ui/use-toast";
+import {
+  Popover,
+  PopoverContent,
+  PopoverTrigger,
+} from "@/components/ui/popover";
+import { Label } from "@/components/ui/label";
+import { NumberField } from "@/components/ui/number-field";
+import { Objects } from "@/utils/objects";
+import {
+  DropdownMenu,
+  DropdownMenuContent,
+  DropdownMenuItem,
+  DropdownMenuTrigger,
+} from "@/components/ui/dropdown-menu";
+import { Tooltip } from "@/components/ui/tooltip";
+import { startCase } from "lodash-es";
+import { ChatBubbleIcon } from "@radix-ui/react-icons";
+import { renderHTML } from "@/plugins/core/RenderHTML";
+import { Input } from "@/components/ui/input";
+import { PopoverAnchor } from "@radix-ui/react-popover";
+
+interface Props {
+  prompts: string[];
+  config: ChatConfig;
+  showConfigurationControls: boolean;
+  sendPrompt(req: SendMessageRequest): Promise<string>;
+  value: ChatMessage[];
+  setValue: (messages: ChatMessage[]) => void;
+}
+
+export const Chatbot: React.FC<Props> = (props) => {
+  const inputRef = React.useRef<HTMLInputElement>(null);
+  const [config, setConfig] = useState(props.config);
+
+  const {
+    messages,
+    setMessages,
+    input,
+    setInput,
+    handleInputChange,
+    handleSubmit,
+    isLoading,
+    stop,
+    error,
+    reload,
+  } = useChat({
+    keepLastMessageOnError: true,
+    streamProtocol: "text",
+    fetch: async (_url, request) => {
+      const body = JSON.parse(request?.body as string) as {
+        messages: ChatMessage[];
+      };
+      try {
+        const response = await props.sendPrompt({
+          ...body,
+          config: {
+            max_tokens: config.maxTokens,
+            temperature: config.temperature,
+            top_p: config.topP,
+            top_k: config.topK,
+            frequency_penalty: config.frequencyPenalty,
+            presence_penalty: config.presencePenalty,
+          },
+        });
+        return new Response(response);
+        // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      } catch (error: any) {
+        // HACK: strip the error message to clean up the response
+        const strippedError = error.message
+          .split("failed with exception ")
+          .pop();
+        return new Response(strippedError, { status: 400 });
+      }
+    },
+    onFinish: (message, { usage, finishReason }) => {
+      Logger.debug("Finished streaming message:", message);
+      Logger.debug("Token usage:", usage);
+      Logger.debug("Finish reason:", finishReason);
+    },
+    onError: (error) => {
+      Logger.error("An error occurred:", error);
+    },
+    onResponse: (response) => {
+      Logger.debug("Received HTTP response from server:", response);
+    },
+  });
+
+  const handleDelete = (id: string) => {
+    setMessages(messages.filter((message) => message.id !== id));
+  };
+
+  const renderMessage = (message: Message) => {
+    return message.role === "assistant"
+      ? renderHTML({ html: message.content })
+      : message.content;
+  };
+
+  return (
+
+ +
+
+ {messages.length === 0 && ( +
+ +

No messages yet

+

+ Start a conversation by typing a message below. +

+
+ )} + {messages.map((message) => ( +
+
+

{renderMessage(message)}

+
+
+ + +
+
+ ))} +
+ + {isLoading && ( +
+ + +
+ )} + + {error && ( +
+ + +
+ )} + +
+ {props.showConfigurationControls && ( + + )} + {props.prompts.length > 0 && ( + { + setInput(prompt); + requestAnimationFrame(() => { + inputRef.current?.focus(); + inputRef.current?.setSelectionRange( + prompt.length, + prompt.length, + ); + }); + }} + /> + )} + + + +
+ ); +}; + +const configDescriptions: Record< + keyof ChatConfig, + { min: number; max: number; description: string; step?: number } +> = { + maxTokens: { + min: 1, + max: 4096, + description: "Maximum number of tokens to generate", + }, + temperature: { + min: 0, + max: 2, + step: 0.1, + description: "Controls randomness (0: deterministic, 2: very random)", + }, + topP: { + min: 0, + max: 1, + step: 0.1, + description: "Nucleus sampling: probability mass to consider", + }, + topK: { + min: 1, + max: 100, + description: + "Top-k sampling: number of highest probability tokens to consider", + }, + frequencyPenalty: { + min: -2, + max: 2, + description: "Penalizes frequent tokens (-2: favor, 2: avoid)", + }, + presencePenalty: { + min: -2, + max: 2, + description: "Penalizes new tokens (-2: favor, 2: avoid)", + }, +}; + +const ConfigPopup: React.FC<{ + config: ChatConfig; + onChange: (newConfig: ChatConfig) => void; +}> = ({ config, onChange }) => { + const [localConfig, setLocalConfig] = useState(config); + const [open, setOpen] = useState(false); + + const handleChange = (key: keyof ChatConfig, value: number) => { + const { min, max } = configDescriptions[key]; + const clampedValue = Math.max(min, Math.min(max, value)); + const newConfig = { ...localConfig, [key]: clampedValue }; + setLocalConfig(newConfig); + onChange(newConfig); + }; + + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === "Enter") { + event.preventDefault(); + setOpen(false); + } + }; + + return ( + + + + + + + +
+

Configuration

+ {Objects.entries(localConfig).map(([key, value]) => ( +
+
+ } + > + + + + handleChange(key, num)} + onKeyDown={handleKeyDown} + className="col-span-3" + /> +
+ ))} +
+ + + ); +}; + +const PromptsPopover: React.FC<{ + prompts: string[]; + onSelect: (prompt: string) => void; +}> = ({ prompts, onSelect }) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [selectedPrompt, setSelectedPrompt] = useState(""); + + const handleSelection = (prompt: string) => { + const variableRegex = /{{(\w+)}}/g; + const matches = [...prompt.matchAll(variableRegex)]; + + if (matches.length > 0) { + setSelectedPrompt(prompt); + setIsPopoverOpen(true); + } else { + onSelect(prompt); + } + }; + + return ( + + + + + + + + + e.preventDefault()} + className="w-64 max-h-96 overflow-y-auto" + > + {prompts.map((prompt, index) => ( + handleSelection(prompt)} + className="whitespace-normal text-left" + > + {prompt} + + ))} + + + + + + setIsPopoverOpen(false)} + onSelect={onSelect} + /> + + + ); +}; + +const PromptVariablesForm: React.FC<{ + prompt: string; + onClose: () => void; + onSelect: (prompt: string) => void; +}> = ({ prompt, onClose, onSelect }) => { + const [variables, setVariables] = useState<{ [key: string]: string }>({}); + + useEffect(() => { + const variableRegex = /{{(\w+)}}/g; + const matches = [...prompt.matchAll(variableRegex)]; + const initialVariables = matches.reduce<{ [key: string]: string }>( + (acc, match) => { + acc[match[1]] = ""; + return acc; + }, + {}, + ); + setVariables(initialVariables); + }, [prompt]); + + const handleVariableChange = (variable: string, value: string) => { + setVariables((prev) => ({ ...prev, [variable]: value })); + }; + + const replacedPrompt = prompt.replaceAll( + /{{(\w+)}}/g, + (_, key) => variables[key] || `{{${key}}}`, + ); + const isSubmitDisabled = Object.values(variables).some( + (value) => value == null || value.trim() === "", + ); + + const handleSubmit = () => { + onSelect(replacedPrompt); + onClose(); + }; + + return ( +
+ {Object.entries(variables).map(([key, value], index) => ( +
+ + handleVariableChange(key, e.target.value)} + rootClassName="col-span-3 w-full" + className="m-0" + placeholder={`Enter value for ${key}`} + autoFocus={index === 0} + onKeyDown={(e) => { + if (e.key === "Enter" && !isSubmitDisabled) { + handleSubmit(); + } + }} + /> +
+ ))} +
+
{replacedPrompt}
+
+ +
+ ); +}; diff --git a/frontend/src/plugins/impl/chat/types.ts b/frontend/src/plugins/impl/chat/types.ts new file mode 100644 index 00000000000..eb8f53f93f0 --- /dev/null +++ b/frontend/src/plugins/impl/chat/types.ts @@ -0,0 +1,26 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +export interface ChatMessage { + role: "system" | "user" | "assistant"; + content: string; +} + +export interface SendMessageRequest { + messages: ChatMessage[]; + config: { + max_tokens?: number; + temperature?: number; + top_p?: number; + top_k?: number; + frequency_penalty?: number; + presence_penalty?: number; + }; +} + +export interface ChatConfig { + maxTokens: number; + temperature: number; + topP: number; + topK: number; + frequencyPenalty: number; + presencePenalty: number; +} diff --git a/frontend/src/plugins/plugins.ts b/frontend/src/plugins/plugins.ts index 80d158b7cac..9ffa25a491a 100644 --- a/frontend/src/plugins/plugins.ts +++ b/frontend/src/plugins/plugins.ts @@ -46,6 +46,7 @@ import { RoutesPlugin } from "./layout/RoutesPlugin"; import { DateTimePickerPlugin } from "./impl/DateTimePickerPlugin"; import { DateRangePickerPlugin } from "./impl/DateRangePlugin"; import { MimeRendererPlugin } from "./layout/MimeRenderPlugin"; +import { ChatPlugin } from "./impl/chat/ChatPlugin"; // List of UI plugins export const UI_PLUGINS: Array> = [ @@ -74,6 +75,7 @@ export const UI_PLUGINS: Array> = [ new TextInputPlugin(), new VegaPlugin(), new PlotlyPlugin(), + ChatPlugin, DataExplorerPlugin, DataFramePlugin, LazyPlugin, diff --git a/marimo/__init__.py b/marimo/__init__.py index ecf252d8d3f..0f1f272c5e0 100644 --- a/marimo/__init__.py +++ b/marimo/__init__.py @@ -17,19 +17,24 @@ from __future__ import annotations __all__ = [ + # Core API "App", "Cell", - "MarimoStopError", "create_asgi_app", "MarimoIslandGenerator", + "MarimoStopError", + # Other namespaces + "ai", + "ui", + # Application elements "accordion", - "carousel", "app_meta", "as_html", "audio", "callout", - "capture_stdout", "capture_stderr", + "capture_stdout", + "carousel", "center", "cli_args", "defs", @@ -47,27 +52,26 @@ "mpl", "nav_menu", "output", - "plain", - "plain_text", "pdf", + "plain_text", + "plain", "query_params", "redirect_stderr", "redirect_stdout", "refs", "right", - "running_in_notebook", "routes", + "running_in_notebook", "show_code", "sidebar", + "sql", "stat", "state", "status", "stop", - "sql", "style", "tabs", "tree", - "ui", "video", "vstack", ] @@ -82,7 +86,7 @@ from marimo._output.justify import center, left, right from marimo._output.md import md from marimo._output.show_code import show_code -from marimo._plugins import ui +from marimo._plugins import ai, ui from marimo._plugins.stateless import mpl, status from marimo._plugins.stateless.accordion import accordion from marimo._plugins.stateless.audio import audio diff --git a/marimo/_dependencies/dependencies.py b/marimo/_dependencies/dependencies.py index c58a06ac078..a5810a4aca2 100644 --- a/marimo/_dependencies/dependencies.py +++ b/marimo/_dependencies/dependencies.py @@ -160,6 +160,7 @@ class DependencyManager: black = Dependency("black") geopandas = Dependency("geopandas") opentelemetry = Dependency("opentelemetry") + anthropic = Dependency("anthropic") @staticmethod def has(pkg: str) -> bool: diff --git a/marimo/_plugins/ai/__init__.py b/marimo/_plugins/ai/__init__.py new file mode 100644 index 00000000000..1ea6d3eab92 --- /dev/null +++ b/marimo/_plugins/ai/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Marimo. All rights reserved. 
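One detail worth spelling out about the wire format: the frontend `ChatConfig` above is camelCase, while `SendMessageRequest.config` (and Python's `ChatModelConfig`) are snake_case; the `fetch` adapter in `chat-ui.tsx` performs the remapping. A small illustrative sketch of the correspondence (values are arbitrary):

```python
# Frontend ChatConfig, as held in React state (camelCase)
ui_config = {"maxTokens": 100, "temperature": 0.5, "topP": 1.0, "topK": 40}

# Payload sent through send_prompt, matching ChatModelConfig (snake_case)
wire_config = {
    "max_tokens": ui_config["maxTokens"],
    "temperature": ui_config["temperature"],
    "top_p": ui_config["topP"],
    "top_k": ui_config["topK"],
}
```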
+"""AI utilities.""" + +from __future__ import annotations + +__all__ = [ + "models", +] + +from marimo._plugins.ui._impl.chat import models diff --git a/marimo/_plugins/ui/__init__.py b/marimo/_plugins/ui/__init__.py index fbf8ee5a322..a4289c4731f 100644 --- a/marimo/_plugins/ui/__init__.py +++ b/marimo/_plugins/ui/__init__.py @@ -8,22 +8,23 @@ __all__ = [ "altair_chart", + "anywidget", "array", "batch", "button", + "chat", "checkbox", "code_editor", "data_explorer", - "date", + "dataframe", "date_range", + "date", "datetime", - "dataframe", "dictionary", "dropdown", - "file", "file_browser", + "file", "form", - "anywidget", "microphone", "multiselect", "number", @@ -43,6 +44,7 @@ from marimo._plugins.ui._impl.altair_chart import altair_chart from marimo._plugins.ui._impl.array import array from marimo._plugins.ui._impl.batch import batch +from marimo._plugins.ui._impl.chat.chat import chat from marimo._plugins.ui._impl.data_explorer import data_explorer from marimo._plugins.ui._impl.dataframes.dataframe import dataframe from marimo._plugins.ui._impl.dates import ( diff --git a/marimo/_plugins/ui/_impl/chat/chat.py b/marimo/_plugins/ui/_impl/chat/chat.py new file mode 100644 index 00000000000..94534c1656d --- /dev/null +++ b/marimo/_plugins/ui/_impl/chat/chat.py @@ -0,0 +1,168 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, Final, List, Optional, cast + +from marimo._output.formatting import as_html +from marimo._output.rich_help import mddoc +from marimo._plugins.core.web_component import JSONType +from marimo._plugins.ui._core.ui_element import UIElement +from marimo._plugins.ui._impl.chat.types import ( + ChatMessage, + ChatModelConfig, + ChatModelConfigDict, +) +from marimo._plugins.ui._impl.chat.utils import from_chat_message_dict +from marimo._runtime.functions import EmptyArgs, Function + + +@dataclass +class SendMessageRequest: + messages: List[ChatMessage] + config: ChatModelConfig + + +@dataclass +class GetChatHistoryResponse: + messages: List[ChatMessage] + + +@mddoc +class chat(UIElement[Dict[str, Any], List[ChatMessage]]): + """ + A chatbot UI element for interactive conversations. + + **Example - Using a custom model.** + + You can define a custom chat model Callable that takes in + the history of messages and configuration. + + The response can be an object, a marimo UI element, or plain text. + + ```python + def my_rag_model(messages, config): + question = messages[-1].content + docs = find_docs(question) + prompt = template(question, docs, messages) + response = query(prompt) + if is_dataset(response): + return dataset_to_chart(response) + return response + + + chat = mo.ui.chat(my_rag_model) + ``` + + **Example - Using a built-in model.** + + You can use a built-in model from the `mo.ai` module. 
+
+    ```python
+    chat = mo.ui.chat(
+        mo.ai.models.openai(
+            "gpt-4o",
+            system_message="You are a helpful assistant.",
+        ),
+    )
+    ```
+
+    **Attributes.**
+
+    - `value`: the current chat history
+
+    **Initialization Args.**
+
+    - `model`: (Callable[[List[ChatMessage], ChatModelConfig], object]) a
+      callable that takes in the chat history and returns a response
+    - `prompts`: optional list of prompts to start the conversation
+    - `on_message`: optional callback function to handle new messages
+    - `show_configuration_controls`: whether to show the configuration controls
+    - `config`: optional ChatModelConfigDict to override the default
+      configuration. Keys include:
+        - `max_tokens`
+        - `temperature`
+        - `top_p`
+        - `top_k`
+        - `frequency_penalty`
+        - `presence_penalty`
+    """
+
+    _name: Final[str] = "marimo-chatbot"
+
+    def __init__(
+        self,
+        model: Callable[[List[ChatMessage], ChatModelConfig], object],
+        *,
+        prompts: Optional[List[str]] = None,
+        on_message: Optional[Callable[[List[ChatMessage]], None]] = None,
+        show_configuration_controls: bool = False,
+        config: Optional[ChatModelConfigDict] = None,
+    ) -> None:
+        self._model = model
+        self._chat_history: List[ChatMessage] = []
+
+        super().__init__(
+            component_name=chat._name,
+            initial_value={"messages": self._chat_history},
+            on_change=on_message,
+            label="",
+            args={
+                "prompts": prompts,
+                "show-configuration-controls": show_configuration_controls,
+                "config": cast(JSONType, config or {}),
+            },
+            functions=(
+                Function(
+                    name="get_chat_history",
+                    arg_cls=EmptyArgs,
+                    function=self._get_chat_history,
+                ),
+                Function(
+                    name="send_prompt",
+                    arg_cls=SendMessageRequest,
+                    function=self._send_prompt,
+                ),
+            ),
+        )
+
+    def _get_chat_history(self, _args: EmptyArgs) -> GetChatHistoryResponse:
+        return GetChatHistoryResponse(messages=self._chat_history)
+
+    def _send_prompt(self, args: SendMessageRequest) -> str:
+        messages = args.messages
+
+        # If the model is a callable that takes a single argument,
+        # call it with just the messages.
+        response: object
+        if (
+            callable(self._model)
+            and not isinstance(self._model, type)
+            and len(inspect.signature(self._model).parameters) == 1
+        ):
+            response = self._model(messages)  # type: ignore
+        else:
+            response = self._model(messages, args.config)
+
+        content = (
+            as_html(response).text  # convert to html if not a string
+            if not isinstance(response, str)
+            else response
+        )
+        self._chat_history = messages + [
+            ChatMessage(role="assistant", content=content)
+        ]
+
+        self._value = self._chat_history
+        if self._on_change:
+            self._on_change(self._value)
+
+        return content
+
+    def _convert_value(self, value: Dict[str, Any]) -> List[ChatMessage]:
+        if not isinstance(value, dict) or "messages" not in value:
+            raise ValueError("Invalid chat history format")
+
+        messages = value["messages"]
+        return [from_chat_message_dict(msg) for msg in messages]
diff --git a/marimo/_plugins/ui/_impl/chat/convert.py b/marimo/_plugins/ui/_impl/chat/convert.py
new file mode 100644
index 00000000000..91a2d41a5f0
--- /dev/null
+++ b/marimo/_plugins/ui/_impl/chat/convert.py
@@ -0,0 +1,62 @@
+# Copyright 2024 Marimo. All rights reserved. 
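A point worth illustrating from `_send_prompt` above: a model that declares a single parameter is invoked with only the message list, while a two-parameter callable also receives the `ChatModelConfig`. A minimal sketch of both accepted signatures (the function names are illustrative, not part of this diff):

```python
import marimo as mo


def history_only(messages):
    # One parameter: _send_prompt calls this as model(messages).
    return f"You said: {messages[-1].content}"


def with_config(messages, config):
    # Two parameters: called as model(messages, config); config carries
    # max_tokens, temperature, and the other sampling knobs (or None).
    return f"(temperature={config.temperature}) {messages[-1].content}"


mo.ui.chat(history_only)  # either form is a valid model
```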
+from __future__ import annotations + +from typing import Any, Dict, List + +from marimo._plugins.ui._impl.chat.types import ChatMessage + + +def convert_to_openai_messages( + messages: List[ChatMessage], +) -> List[Dict[Any, Any]]: + openai_messages: List[Dict[Any, Any]] = [] + + for message in messages: + parts: List[Dict[Any, Any]] = [] + + parts.append({"type": "text", "text": message.content}) + + if message.attachments: + for attachment in message.attachments: + if attachment.content_type.startswith("image"): + parts.append( + { + "type": "image_url", + "image_url": {"url": attachment.url}, + } + ) + + elif attachment.content_type.startswith("text"): + parts.append({"type": "text", "text": attachment.url}) + + openai_messages.append({"role": message.role, "content": parts}) + + return openai_messages + + +def convert_to_anthropic_messages( + messages: List[ChatMessage], +) -> List[Dict[Any, Any]]: + anthropic_messages: List[Dict[Any, Any]] = [] + + for message in messages: + parts: List[Dict[Any, Any]] = [] + + parts.append({"type": "text", "text": message.content}) + + if message.attachments: + for attachment in message.attachments: + if attachment.content_type.startswith("image"): + parts.append( + { + "type": "image_url", + "image_url": {"url": attachment.url}, + } + ) + + elif attachment.content_type.startswith("text"): + parts.append({"type": "text", "text": attachment.url}) + + anthropic_messages.append({"role": message.role, "content": parts}) + + return anthropic_messages diff --git a/marimo/_plugins/ui/_impl/chat/models.py b/marimo/_plugins/ui/_impl/chat/models.py new file mode 100644 index 00000000000..2f60d5480b5 --- /dev/null +++ b/marimo/_plugins/ui/_impl/chat/models.py @@ -0,0 +1,229 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +import os +from typing import Callable, List, Optional, cast + +from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.chat.convert import ( + convert_to_anthropic_messages, + convert_to_openai_messages, +) +from marimo._plugins.ui._impl.chat.types import ( + ChatMessage, + ChatModel, + ChatModelConfig, +) + +DEFAULT_SYSTEM_MESSAGE = ( + "You are a helpful assistant specializing in data science." +) + + +class simple(ChatModel): + """ + Convenience class for wrapping a ChatModel or callable to + take a single prompt + + **Args:** + + - delegate (Callable[[str], str]): A callable that takes a + single prompt and returns a response + """ + + def __init__(self, delegate: Callable[[str], object]): + self.delegate = delegate + + def __call__( + self, messages: List[ChatMessage], config: ChatModelConfig + ) -> object: + del config + prompt = messages[-1].content + return self.delegate(prompt) + + +class openai(ChatModel): + """ + OpenAI ChatModel + + **Args:** + + - model (str): The model to use. + Can be found on the [OpenAI models page](https://platform.openai.com/docs/models) + - system_message (str): The system message to use + - api_key (Optional[str]): The API key to use. + If not provided, the API key will be retrieved + from the OPENAI_API_KEY environment variable or the user's config. 
+    - base_url (Optional[str]): The base URL to use
+    """
+
+    def __init__(
+        self,
+        model: str,
+        *,
+        system_message: str = DEFAULT_SYSTEM_MESSAGE,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        self.model = model
+        self.system_message = system_message
+        self.api_key = api_key
+        self.base_url = base_url
+
+    @property
+    def _require_api_key(self) -> str:
+        # If the api key is provided, use it
+        if self.api_key is not None:
+            return self.api_key
+
+        # Then check the environment variable
+        env_key = os.environ.get("OPENAI_API_KEY")
+        if env_key is not None:
+            return env_key
+
+        # Then check the user's config
+        try:
+            from marimo._runtime.context.types import get_context
+
+            api_key = get_context().user_config["ai"]["open_ai"]["api_key"]
+            if api_key:
+                return api_key
+        except Exception:
+            pass
+
+        raise ValueError(
+            "openai api key not provided. Pass it as an argument or "
+            "set OPENAI_API_KEY as an environment variable"
+        )
+
+    def __call__(
+        self, messages: List[ChatMessage], config: ChatModelConfig
+    ) -> object:
+        DependencyManager.openai.require(
+            "chat model requires openai. `pip install openai`"
+        )
+        from openai import OpenAI  # type: ignore[import-not-found]
+        from openai.types.chat import (  # type: ignore[import-not-found]
+            ChatCompletionMessageParam,
+        )
+
+        client = OpenAI(
+            api_key=self._require_api_key,
+            base_url=self.base_url,
+        )
+
+        openai_messages = convert_to_openai_messages(
+            [ChatMessage(role="system", content=self.system_message)]
+            + messages
+        )
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=cast(List[ChatCompletionMessageParam], openai_messages),
+            max_tokens=config.max_tokens,
+            temperature=config.temperature,
+            top_p=config.top_p,
+            frequency_penalty=config.frequency_penalty,
+            presence_penalty=config.presence_penalty,
+            stream=False,
+        )
+
+        choice = response.choices[0]
+        content = choice.message.content
+        return content or ""
+
+
+class anthropic(ChatModel):
+    """
+    Anthropic ChatModel
+
+    **Args:**
+
+    - model (str): The model to use.
+    - system_message (str): The system message to use
+    - api_key (Optional[str]): The API key to use.
+      If not provided, the API key will be retrieved
+      from the ANTHROPIC_API_KEY environment variable
+      or the user's config.
+    - base_url (Optional[str]): The base URL to use
+    """
+
+    def __init__(
+        self,
+        model: str,
+        *,
+        system_message: str = DEFAULT_SYSTEM_MESSAGE,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+    ):
+        self.model = model
+        self.system_message = system_message
+        self.api_key = api_key
+        self.base_url = base_url
+
+    @property
+    def _require_api_key(self) -> str:
+        # If the api key is provided, use it
+        if self.api_key is not None:
+            return self.api_key
+
+        # Then check the user's config
+        try:
+            from marimo._runtime.context.types import get_context
+
+            api_key = get_context().user_config["ai"]["anthropic"]["api_key"]
+            if api_key:
+                return api_key
+        except Exception:
+            pass
+
+        # Then check the environment variable
+        env_key = os.environ.get("ANTHROPIC_API_KEY")
+        if env_key is not None:
+            return env_key
+
+        raise ValueError(
+            "anthropic api key not provided. Pass it as an argument or "
+            "set ANTHROPIC_API_KEY as an environment variable"
+        )
+
+    def __call__(
+        self, messages: List[ChatMessage], config: ChatModelConfig
+    ) -> object:
+        DependencyManager.anthropic.require(
+            "chat model requires anthropic. 
`pip install anthropic`" + ) + from anthropic import ( # type: ignore[import-not-found] + NOT_GIVEN, + Anthropic, + ) + from anthropic.types.message_param import ( # type: ignore[import-not-found] + MessageParam, + ) + + client = Anthropic( + api_key=self._require_api_key, + base_url=self.base_url, + ) + + anthropic_messages = convert_to_anthropic_messages(messages) + response = client.messages.create( + model=self.model, + system=self.system_message, + max_tokens=config.max_tokens or 1000, + messages=cast(List[MessageParam], anthropic_messages), + top_p=config.top_p if config.top_p is not None else NOT_GIVEN, + top_k=config.top_k if config.top_k is not None else NOT_GIVEN, + stream=False, + temperature=config.temperature + if config.temperature is not None + else NOT_GIVEN, + ) + + content = response.content + if len(content) > 0: + if content[0].type == "text": + return content[0].text + elif content[0].type == "tool_use": + return content + return "" diff --git a/marimo/_plugins/ui/_impl/chat/types.py b/marimo/_plugins/ui/_impl/chat/types.py new file mode 100644 index 00000000000..0ae4a3e26b8 --- /dev/null +++ b/marimo/_plugins/ui/_impl/chat/types.py @@ -0,0 +1,79 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import List, Literal, Optional, TypedDict + + +class ChatAttachmentDict(TypedDict): + name: str + content_type: str + url: str + + +class ChatMessageDict(TypedDict): + role: Literal["user", "assistant", "system"] + content: str + attachments: Optional[List[ChatAttachmentDict]] + + +class ChatModelConfigDict(TypedDict, total=False): + max_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + frequency_penalty: Optional[float] + presence_penalty: Optional[float] + + +# NOTE: The following classes are public API. +# Any changes must be backwards compatible. + + +@dataclass +class ChatAttachment: + # The name of the attachment, usually the file name. + name: str + + # A string indicating the [media type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type). + # By default, it's extracted from the pathname's extension. + content_type: str + + # The URL of the attachment. It can either be a URL to a hosted file or a + # [Data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs). + url: str + + +@dataclass +class ChatMessage: + """ + A message in a chat. + """ + + # The role of the message. + role: Literal["user", "assistant", "system"] + + # The content of the message. + content: str + + # Optional attachments to the message. + attachments: Optional[List[ChatAttachment]] = None + + +@dataclass +class ChatModelConfig: + max_tokens: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + + +class ChatModel(abc.ABC): + @abc.abstractmethod + def __call__( + self, messages: List[ChatMessage], config: ChatModelConfig + ) -> object: + pass diff --git a/marimo/_plugins/ui/_impl/chat/utils.py b/marimo/_plugins/ui/_impl/chat/utils.py new file mode 100644 index 00000000000..ffe1fd8d407 --- /dev/null +++ b/marimo/_plugins/ui/_impl/chat/utils.py @@ -0,0 +1,35 @@ +# Copyright 2024 Marimo. All rights reserved. 
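Since `ChatModel` in `types.py` is an ABC with a single `__call__`, additional providers can be written without touching the plugin. A hypothetical custom model implementing the interface (not part of this diff):

```python
from typing import List

from marimo._plugins.ui._impl.chat.types import (
    ChatMessage,
    ChatModel,
    ChatModelConfig,
)


class reversed_echo(ChatModel):
    """Toy provider: replies with the last message's content reversed."""

    def __call__(
        self, messages: List[ChatMessage], config: ChatModelConfig
    ) -> object:
        del config  # a real provider would forward the sampling options
        return messages[-1].content[::-1]


# Usable anywhere a model is expected, e.g. mo.ui.chat(reversed_echo())
```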
+from __future__ import annotations + +from typing import List, Optional + +from marimo._plugins.ui._impl.chat.types import ( + ChatAttachment, + ChatMessage, + ChatMessageDict, +) + + +def from_chat_message_dict(d: ChatMessageDict) -> ChatMessage: + if isinstance(d, ChatMessage): + return d + + attachments_dict = d.get("attachments", None) + attachments: Optional[List[ChatAttachment]] = None + if attachments_dict is not None: + attachments = [ + ChatAttachment( + name=attachment["name"], + content_type=attachment["content_type"], + url=attachment["url"], + ) + for attachment in attachments_dict + ] + else: + attachments = None + + return ChatMessage( + role=d["role"], + content=d["content"], + attachments=attachments, + ) diff --git a/marimo/_plugins/validators.py b/marimo/_plugins/validators.py index 08e2833db26..ed7a5bc5862 100644 --- a/marimo/_plugins/validators.py +++ b/marimo/_plugins/validators.py @@ -1,3 +1,4 @@ +# Copyright 2024 Marimo. All rights reserved. from __future__ import annotations import warnings diff --git a/marimo/_runtime/packages/module_name_to_pypi_name.py b/marimo/_runtime/packages/module_name_to_pypi_name.py index f733fd1cd1a..68e1f4fa906 100644 --- a/marimo/_runtime/packages/module_name_to_pypi_name.py +++ b/marimo/_runtime/packages/module_name_to_pypi_name.py @@ -413,6 +413,7 @@ def module_name_to_pypi_name() -> dict[str, str]: "elasticluster": "azure-elasticluster-current", "elftools": "pyelftools", "elixir": "Elixir", + "ell": "ell-ai", "emlib": "empy", "enchant": "pyenchant", "encutils": "cssutils", diff --git a/marimo/_smoke_tests/chat/chatbot.py b/marimo/_smoke_tests/chat/chatbot.py new file mode 100644 index 00000000000..ab7efa38bca --- /dev/null +++ b/marimo/_smoke_tests/chat/chatbot.py @@ -0,0 +1,229 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "ell-ai==0.0.12", +# "marimo", +# "openai==1.50.1", +# "pydantic==2.9.2", +# "vega-datasets==0.9.0", +# ] +# /// + +import marimo + +__generated_with = "0.8.20" +app = marimo.App(width="medium") + + +@app.cell +def __(mo): + mo.md(r"""# Built-in chatbots""") + return + + +@app.cell +def __(mo): + mo.md(r"""## OpenAI""") + return + + +@app.cell +def __(mo): + mo.ui.chat( + mo.ai.models.openai( + "gpt-4-turbo", system_message="You are a helpful data scientist" + ), + show_configuration_controls=True, + prompts=[ + "Tell me a joke", + "What is the meaning of life?", + "What is 2 + {{number}}", + ], + ) + return + + +@app.cell +def __(mo): + mo.md(r"""## Anthropic""") + return + + +@app.cell +def __(mo): + mo.ui.chat( + mo.ai.models.anthropic("claude-3-5-sonnet-20240620"), + show_configuration_controls=True, + prompts=[ + "Tell me a joke", + "What is the meaning of life?", + "What is 2 + {{number}}", + ], + ) + return + + +@app.cell +def __(): + import marimo as mo + return (mo,) + + +@app.cell +def __(mo): + mo.md(r"""# Custom chatbots""") + return + + +@app.cell(hide_code=True) +def __(mo): + import os + + os_key = os.environ.get("OPENAI_API_KEY") + input_key = mo.ui.text(label="OpenAI API key", kind="password") + input_key if not os_key else None + return input_key, os, os_key + + +@app.cell +def __(input_key, os_key): + openai_key = os_key or input_key.value + return (openai_key,) + + +@app.cell(hide_code=True) +def __(mo, openai_key): + # Initialize a client + mo.stop( + not openai_key, + "Please set the OPENAI_API_KEY environment variable or provide it in the input field", + ) + + import ell + import openai + + # Create an openai client + client = openai.Client(api_key=openai_key) + 
return client, ell, openai + + +@app.cell +def __(mo): + mo.md(r"""## Simple""") + return + + +@app.cell +def __(client, ell, mo): + @ell.simple("gpt-4o-mini-2024-07-18", client=client) + def _my_model(prompt): + """You are an annoying little brother, whatever I say, be sassy with your response""" + return prompt + + + mo.ui.chat(mo.ai.models.simple(_my_model)) + return + + +@app.cell +def __(mo): + mo.md(r"""## Complex""") + return + + +@app.cell +def __(): + # Grab a dataset for the chatbot conversation, we will use the cars dataset + + from vega_datasets import data + + cars = data.cars() + return cars, data + + +@app.cell +def __(cars, client, ell): + from pydantic import BaseModel, Field + + + class PromptsResponse(BaseModel): + prompts: list[str] = Field( + description="A list of prompts to use for the chatbot" + ) + + + @ell.complex( + "gpt-4o-mini-2024-07-18", client=client, response_format=PromptsResponse + ) + def get_sample_prompts(df): + """You are a helpful data scientist""" + return ( + "Given the following schema of this dataset, " + f"what would be three interesting questions to ask? \n{df.dtypes}" + ) + + + def my_complex_model(messages, config): + schema = cars.dtypes + + # This doesn't need to be ell or any model provider + # You can use your own model here. + @ell.complex(model="gpt-4o", temperature=0.7) + def chat_bot(message_history): + return [ + ell.system(f""" + You are a helpful data scientist chatbot. + + I would like you to analyze this dataset. You must only ask follow-up questions or return a single valid JSON of a vega-lite specification so that it can be charted. + + Here is the dataset schema {schema}. + + If you are returning JSON, only return the json without any explanation. And don't wrap in backquotes or code fences + """), + ] + message_history + + # History + message_history = [ + ell.user(message.content) + if message.role == "user" + else ell.assistant(message.content) + for message in messages + ] + # Prompt + # message_history.append(ell.user(prompt)) + + # Go! + response = chat_bot(message_history).text + if response.startswith("{"): + import altair as alt + import json + + as_dict = json.loads(response) + # add our cars dataset + print(as_dict) + as_dict["data"] = {"values": cars.dropna().to_dict(orient="records")} + if "datasets" in as_dict: + del as_dict["datasets"] + return alt.Chart.from_dict(as_dict) + return response + return ( + BaseModel, + Field, + PromptsResponse, + get_sample_prompts, + my_complex_model, + ) + + +@app.cell +def __(cars, get_sample_prompts, mo, my_complex_model): + prompts = get_sample_prompts(cars).parsed.prompts + mo.ui.chat( + my_complex_model, + prompts=prompts, + ) + return (prompts,) + + +if __name__ == "__main__": + app.run() diff --git a/tests/_plugins/ui/_impl/chat/test_chat.py b/tests/_plugins/ui/_impl/chat/test_chat.py new file mode 100644 index 00000000000..9b55030fe37 --- /dev/null +++ b/tests/_plugins/ui/_impl/chat/test_chat.py @@ -0,0 +1,166 @@ +# Copyright 2024 Marimo. All rights reserved. 
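The smoke test above also exercises prompt templating: a prompt containing `{{variable}}` placeholders opens a small form in the frontend (see `PromptVariablesForm` in `chat-ui.tsx`), and the collected values are substituted before the prompt is sent. A minimal sketch of how such prompts are declared (the echo model is illustrative):

```python
import marimo as mo


def echo(messages, config):
    return messages[-1].content


mo.ui.chat(
    echo,
    prompts=[
        "Tell me a joke",  # sent as-is
        "What is 2 + {{number}}",  # asks for `number` before sending
    ],
)
```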
+from __future__ import annotations + +from typing import Dict, List + +import pytest + +from marimo._plugins import ui +from marimo._plugins.ui._impl.chat.chat import SendMessageRequest +from marimo._plugins.ui._impl.chat.types import ( + ChatMessage, + ChatModelConfig, + ChatModelConfigDict, +) +from marimo._runtime.functions import EmptyArgs + + +def test_chat_init(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + chat = ui.chat(mock_model) + assert chat._model == mock_model + assert chat._chat_history == [] + assert chat.value == [] + + +def test_chat_with_prompts(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + prompts: List[str] = ["Hello", "How are you?"] + chat = ui.chat(mock_model, prompts=prompts) + assert chat._component_args["prompts"] == prompts + + +def test_chat_with_config(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + config: ChatModelConfigDict = {"temperature": 0.7, "max_tokens": 100} + chat = ui.chat(mock_model, config=config) + assert chat._component_args["config"] == config + + +def test_chat_send_prompt(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del config + return f"Response to: {messages[-1].content}" + + chat = ui.chat(mock_model) + request = SendMessageRequest( + messages=[ChatMessage(role="user", content="Hello")], + config=ChatModelConfig(), + ) + response: str = chat._send_prompt(request) + + assert response == "Response to: Hello" + assert len(chat._chat_history) == 2 + assert chat._chat_history[0].role == "user" + assert chat._chat_history[0].content == "Hello" + assert chat._chat_history[1].role == "assistant" + assert chat._chat_history[1].content == "Response to: Hello" + + +def test_chat_get_history(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + chat = ui.chat(mock_model) + chat._chat_history = [ + ChatMessage(role="user", content="Hello"), + ChatMessage(role="assistant", content="Hi there!"), + ] + + history = chat._get_chat_history(EmptyArgs()) + assert len(history.messages) == 2 + assert history.messages[0].role == "user" + assert history.messages[0].content == "Hello" + assert history.messages[1].role == "assistant" + assert history.messages[1].content == "Hi there!" + + +def test_chat_convert_value(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + chat = ui.chat(mock_model) + value: Dict[str, List[Dict[str, str]]] = { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + + converted: List[ChatMessage] = chat._convert_value(value) + assert len(converted) == 2 + assert converted[0].role == "user" + assert converted[0].content == "Hello" + assert converted[1].role == "assistant" + assert converted[1].content == "Hi there!" 
+ + +def test_chat_convert_value_invalid(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + chat = ui.chat(mock_model) + + with pytest.raises(ValueError, match="Invalid chat history format"): + chat._convert_value({"invalid": "format"}) + + +def test_chat_with_on_message(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + on_message_called = False + + def on_message(messages: List[ChatMessage]) -> None: + del messages + nonlocal on_message_called + on_message_called = True + + chat = ui.chat(mock_model, on_message=on_message) + request = SendMessageRequest( + messages=[ChatMessage(role="user", content="Hello")], + config=ChatModelConfig(), + ) + chat._send_prompt(request) + + assert on_message_called + + +def test_chat_with_show_configuration_controls(): + def mock_model( + messages: List[ChatMessage], config: ChatModelConfig + ) -> str: + del messages, config + return "Mock response" + + chat = ui.chat(mock_model, show_configuration_controls=True) + assert chat._component_args["show-configuration-controls"] is True diff --git a/tests/_plugins/ui/_impl/chat/test_chat_convert.py b/tests/_plugins/ui/_impl/chat/test_chat_convert.py new file mode 100644 index 00000000000..a7f9010a8ab --- /dev/null +++ b/tests/_plugins/ui/_impl/chat/test_chat_convert.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from typing import List + +import pytest + +from marimo._plugins.ui._impl.chat.convert import ( + convert_to_anthropic_messages, + convert_to_openai_messages, +) +from marimo._plugins.ui._impl.chat.types import ( + ChatAttachment, + ChatMessage, +) +from marimo._plugins.ui._impl.chat.utils import from_chat_message_dict + + +@pytest.fixture +def sample_messages() -> List[ChatMessage]: + return [ + ChatMessage( + role="user", + content="Hello, I have a question.", + attachments=[ + ChatAttachment( + name="image.png", + content_type="image/png", + url="http://example.com/image.png", + ), + ChatAttachment( + name="text.txt", + content_type="text/plain", + url="http://example.com/text.txt", + ), + ], + ), + ChatMessage( + role="assistant", + content="Sure, I'd be happy to help. What's your question?", + attachments=[], + ), + ] + + +def test_convert_to_openai_messages(sample_messages: List[ChatMessage]): + result = convert_to_openai_messages(sample_messages) + + assert len(result) == 2 + + # Check user message + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + assert result[0]["content"][0] == { + "type": "text", + "text": "Hello, I have a question.", + } + assert result[0]["content"][1] == { + "type": "image_url", + "image_url": {"url": "http://example.com/image.png"}, + } + assert result[0]["content"][2] == { + "type": "text", + "text": "http://example.com/text.txt", + } + + # Check assistant message + assert result[1]["role"] == "assistant" + assert len(result[1]["content"]) == 1 + assert result[1]["content"][0] == { + "type": "text", + "text": "Sure, I'd be happy to help. 
What's your question?", + } + + +def test_convert_to_anthropic_messages(sample_messages: List[ChatMessage]): + result = convert_to_anthropic_messages(sample_messages) + + assert len(result) == 2 + + # Check user message + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + assert result[0]["content"][0] == { + "type": "text", + "text": "Hello, I have a question.", + } + assert result[0]["content"][1] == { + "type": "image_url", + "image_url": {"url": "http://example.com/image.png"}, + } + assert result[0]["content"][2] == { + "type": "text", + "text": "http://example.com/text.txt", + } + + # Check assistant message + assert result[1]["role"] == "assistant" + assert len(result[1]["content"]) == 1 + assert result[1]["content"][0] == { + "type": "text", + "text": "Sure, I'd be happy to help. What's your question?", + } + + +def test_empty_messages(): + empty_messages = [] + assert convert_to_openai_messages(empty_messages) == [] + assert convert_to_anthropic_messages(empty_messages) == [] + + +def test_message_without_attachments(): + messages = [ + ChatMessage( + role="user", content="Just a simple message", attachments=[] + ) + ] + + openai_result = convert_to_openai_messages(messages) + assert len(openai_result) == 1 + assert openai_result[0]["role"] == "user" + assert len(openai_result[0]["content"]) == 1 + assert openai_result[0]["content"][0] == { + "type": "text", + "text": "Just a simple message", + } + + anthropic_result = convert_to_anthropic_messages(messages) + assert len(anthropic_result) == 1 + assert anthropic_result[0]["role"] == "user" + assert len(anthropic_result[0]["content"]) == 1 + assert anthropic_result[0]["content"][0] == { + "type": "text", + "text": "Just a simple message", + } + + +def test_from_chat_message_dict(): + # Test case 1: ChatMessage with attachments + message_dict = { + "role": "user", + "content": "Hello, this is a test message.", + "attachments": [ + { + "name": "test.png", + "content_type": "image/png", + "url": "http://example.com/test.png", + } + ], + } + + result = from_chat_message_dict(message_dict) + + assert isinstance(result, ChatMessage) + assert result.role == "user" + assert result.content == "Hello, this is a test message." + assert len(result.attachments) == 1 + assert isinstance(result.attachments[0], ChatAttachment) + assert result.attachments[0].name == "test.png" + assert result.attachments[0].content_type == "image/png" + assert result.attachments[0].url == "http://example.com/test.png" + + # Test case 2: ChatMessage without attachments + message_dict_no_attachments = { + "role": "assistant", + "content": "This is a response without attachments.", + } + + result_no_attachments = from_chat_message_dict(message_dict_no_attachments) + + assert isinstance(result_no_attachments, ChatMessage) + assert result_no_attachments.role == "assistant" + assert ( + result_no_attachments.content + == "This is a response without attachments." 
+ ) + assert result_no_attachments.attachments is None + + # Test case 3: Input is already a ChatMessage + existing_chat_message = ChatMessage( + role="system", content="System message", attachments=None + ) + + result_existing = from_chat_message_dict(existing_chat_message) + + assert result_existing is existing_chat_message diff --git a/tests/_plugins/ui/_impl/chat/test_chat_model.py b/tests/_plugins/ui/_impl/chat/test_chat_model.py new file mode 100644 index 00000000000..7780ce88396 --- /dev/null +++ b/tests/_plugins/ui/_impl/chat/test_chat_model.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from marimo._plugins.ui._impl.chat.models import simple +from marimo._plugins.ui._impl.chat.types import ChatMessage, ChatModelConfig + + +def test_simple_model(): + model = simple(lambda x: x * 2) + assert ( + model([ChatMessage(role="user", content="hey")], ChatModelConfig()) + == "heyhey" + ) + + assert ( + model( + [ + ChatMessage(role="user", content="hey", attachments=[]), + ChatMessage(role="user", content="goodbye", attachments=[]), + ], + ChatModelConfig(), + ) + == "goodbyegoodbye" + )
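As the tests above suggest, a chat element can be driven programmatically, which makes custom models easy to unit-test without the frontend. A hedged sketch using only the internals introduced in this diff:

```python
from marimo._plugins import ui
from marimo._plugins.ui._impl.chat.chat import SendMessageRequest
from marimo._plugins.ui._impl.chat.types import ChatMessage, ChatModelConfig

chat = ui.chat(lambda messages, config: f"Echo: {messages[-1].content}")
reply = chat._send_prompt(
    SendMessageRequest(
        messages=[ChatMessage(role="user", content="hi")],
        config=ChatModelConfig(temperature=0.2),
    )
)
assert reply == "Echo: hi"
assert chat.value[-1].role == "assistant"  # history now ends with the reply
```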