From f257b96663fc791a7ebcb511268c046f1f8565ff Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Thu, 7 Nov 2024 08:40:52 -0800 Subject: [PATCH 1/6] fix: switch crypto uuid to regular uuid --- website/package-lock.json | 22 ++++++++++++++++++++++ website/package.json | 2 ++ website/src/components/PipelineGui.tsx | 16 +++------------- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/website/package-lock.json b/website/package-lock.json index 58cc7248..e7d923e1 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -63,6 +63,7 @@ "style-loader": "^4.0.0", "tailwind-merge": "^2.5.2", "tailwindcss-animate": "^1.0.7", + "uuid": "^11.0.2", "zod": "^3.23.8" }, "devDependencies": { @@ -74,6 +75,7 @@ "@types/react": "^18", "@types/react-beautiful-dnd": "^13.1.8", "@types/react-dom": "^18", + "@types/uuid": "^10.0.0", "eslint": "^8.57.1", "eslint-config-next": "14.2.11", "eslint-plugin-react": "^7.37.2", @@ -3861,6 +3863,13 @@ "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz", "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==" }, + "node_modules/@types/uuid": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-7gqG38EyHgyP1S+7+xomFtL+ZNHcKv6DwNaCZmJmo1vgMugyF3TCnXVg4t1uk89mLNwnLtnY3TpOpCOyp1/xHQ==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/yargs": { "version": "17.0.33", "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-17.0.33.tgz", @@ -11986,6 +11995,19 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" }, + "node_modules/uuid": { + "version": "11.0.2", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-11.0.2.tgz", + "integrity": "sha512-14FfcOJmqdjbBPdDjFQyk/SdT4NySW4eM0zcG+HqbHP5jzuH56xO3J1DGhgs/cEMCfwYi3HQI1gnTO62iaG+tQ==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/esm/bin/uuid" + } + }, "node_modules/vfile": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", diff --git a/website/package.json b/website/package.json index 4ca7c52c..013cf86b 100644 --- a/website/package.json +++ b/website/package.json @@ -64,6 +64,7 @@ "style-loader": "^4.0.0", "tailwind-merge": "^2.5.2", "tailwindcss-animate": "^1.0.7", + "uuid": "^11.0.2", "zod": "^3.23.8" }, "devDependencies": { @@ -75,6 +76,7 @@ "@types/react": "^18", "@types/react-beautiful-dnd": "^13.1.8", "@types/react-dom": "^18", + "@types/uuid": "^10.0.0", "eslint": "^8.57.1", "eslint-config-next": "14.2.11", "eslint-plugin-react": "^7.37.2", diff --git a/website/src/components/PipelineGui.tsx b/website/src/components/PipelineGui.tsx index 80d23593..9fab7dcb 100644 --- a/website/src/components/PipelineGui.tsx +++ b/website/src/components/PipelineGui.tsx @@ -53,6 +53,7 @@ import { useWebSocket } from "@/contexts/WebSocketContext"; import { Input } from "@/components/ui/input"; import path from "path"; import { schemaDictToItemSet } from "./utils"; +import { v4 as uuidv4 } from "uuid"; const PipelineGUI: React.FC = () => { const fileInputRef = useRef(null); @@ -67,8 +68,6 @@ const PipelineGUI: React.FC = () => { setNumOpRun, currentFile, setCurrentFile, - output, - unsavedChanges, setFiles, setOutput, isLoadingOutputs, @@ -78,11 +77,8 @@ const PipelineGUI: React.FC = () => { defaultModel, setDefaultModel, setTerminalOutput, - saveProgress, - clearPipelineState, optimizerModel, setOptimizerModel, - optimizerProgress, setOptimizerProgress, } = usePipelineContext(); const [isSettingsOpen, setIsSettingsOpen] = useState(false); @@ -98,9 +94,6 @@ const PipelineGUI: React.FC = () => { const { toast } = useToast(); const { connect, sendMessage, lastMessage, readyState, disconnect } = useWebSocket(); - const [runningButtonType, setRunningButtonType] = useState< - "run" | "clear-run" | null - >(null); useEffect(() => { if (lastMessage) { @@ -139,7 +132,7 @@ const PipelineGUI: React.FC = () => { const existingOp = operations.find((op) => op.name === name); return { - id: id || crypto.randomUUID(), + id: id || uuidv4(), llmType: type === "map" || type === "reduce" || @@ -254,7 +247,7 @@ const PipelineGUI: React.FC = () => { } return { - id: id || crypto.randomUUID(), + id: id || uuidv4(), llmType: type === "map" || type === "reduce" || @@ -386,7 +379,6 @@ const PipelineGUI: React.FC = () => { const lastOperation = operations[lastOpIndex]; setOptimizerProgress(null); setIsLoadingOutputs(true); - setRunningButtonType(clear_intermediate ? "clear-run" : "run"); setNumOpRun((prevNum) => { const newNum = prevNum + operations.length; const updatedOperations = operations.map((op, index) => ({ @@ -445,7 +437,6 @@ const PipelineGUI: React.FC = () => { // Close the WebSocket connection disconnect(); setIsLoadingOutputs(false); - setRunningButtonType(null); } }, [ @@ -510,7 +501,6 @@ const PipelineGUI: React.FC = () => { const handleStop = () => { sendMessage("kill"); - setRunningButtonType(null); if (readyState === WebSocket.CLOSED && isLoadingOutputs) { setIsLoadingOutputs(false); From cb778485ea1217af9747a1fb3fa0742e5ef30b65 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Thu, 7 Nov 2024 12:29:26 -0800 Subject: [PATCH 2/6] Provide defaults in the UI chat --- website/src/components/AIChatPanel.tsx | 107 +++++++++++++++++-------- 1 file changed, 75 insertions(+), 32 deletions(-) diff --git a/website/src/components/AIChatPanel.tsx b/website/src/components/AIChatPanel.tsx index a1de978a..6fce6344 100644 --- a/website/src/components/AIChatPanel.tsx +++ b/website/src/components/AIChatPanel.tsx @@ -2,7 +2,7 @@ import React, { useRef, useState, useEffect } from "react"; import { ResizableBox } from "react-resizable"; -import { X } from "lucide-react"; +import { Eraser, RefreshCw, X } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { ScrollArea } from "@/components/ui/scroll-area"; @@ -18,6 +18,12 @@ interface AIChatPanelProps { onClose: () => void; } +const DEFAULT_SUGGESTIONS = [ + "Go over current outputs", + "Help me refine my current operation prompt", + "Am I doing this right?", +]; + const AIChatPanel: React.FC = ({ onClose }) => { const [position, setPosition] = useState({ x: window.innerWidth - 400, @@ -36,6 +42,7 @@ const AIChatPanel: React.FC = ({ onClose }) => { } = useChat({ api: "/api/chat", initialMessages: [], + id: "persistent-chat", }); const { serializeState } = usePipelineContext(); @@ -114,6 +121,10 @@ ${pipelineState}`, handleSubmit(e); }; + const handleClearMessages = () => { + setMessages([]); + }; + return (
- +
+ + +
- {messages - .filter((message) => message.role !== "system") - .map((message, index) => ( -
+ {messages.filter((message) => message.role !== "system").length === + 0 ? ( +
+ {DEFAULT_SUGGESTIONS.map((suggestion, index) => ( + + ))} +
+ ) : ( + messages + .filter((message) => message.role !== "system") + .map((message, index) => (
- {message.content} +
+ {message.content} +
-
- ))} + )) + )} {isLoading && (
From a58d7a70ac65a0d3f2615bcde9d1aa901e4dc575 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Sat, 9 Nov 2024 12:48:30 -0800 Subject: [PATCH 3/6] fix: add types to UI assistant prompt --- website/src/components/AIChatPanel.tsx | 37 ++++++++++++++++++------ website/src/contexts/PipelineContext.tsx | 24 +++++++++++++-- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/website/src/components/AIChatPanel.tsx b/website/src/components/AIChatPanel.tsx index 6fce6344..e4d80a98 100644 --- a/website/src/components/AIChatPanel.tsx +++ b/website/src/components/AIChatPanel.tsx @@ -100,17 +100,36 @@ const AIChatPanel: React.FC = ({ onClose }) => { { id: String(Date.now()), role: "system", - content: `You are the DocETL assistant, helping users build and refine data analysis pipelines. You are an expert at data analysis. DocETL enables users to create sophisticated data processing workflows that combine the power of LLMs with traditional data operations. + content: `You are the DocETL assistant, helping users build and refine data analysis pipelines. You are an expert at data analysis. -Each pipeline processes documents (or key-value pairs) through a sequence of operations. LLM operations like 'map' (process individual documents), 'reduce' (analyze multiple documents together), 'resolve' (entity resolution), and 'filter' (conditionally retain documents) can be combined with utility operations like 'unnest', 'split', 'gather', and 'sample' to create powerful analysis flows. -Every LLM operation has a prompt and an output schema, which determine the keys to be added to the documents as the operation runs. Prompts are jinja2 templates, and output schemas are JSON schemas. For 'map' and 'filter' operations, you can reference the current document as 'input' and access its keys as '{{ input.keyname }}'. For 'reduce' operations, you can reference the group of input documents as 'inputs' and loop over them with '{% for doc in inputs %}...{% endfor %}'. For 'resolve' operations, there are two prompts: the comparison_prompt compares two documents (referenced as '{{ input1 }}' and '{{ input2 }}') to determine if they match, while the resolution_prompt takes a group of matching documents (referenced as 'inputs') and canonicalizes them into a single output following the operation's schema. -You should help users optimize their pipelines and overcome any challenges they encounter. Ask questions to better understand their goals, suggest improvements, help debug issues, or explore new approaches to achieve their analysis objectives. Only ask one question at a time, and keep your responses concise. -Also, don't give lots of suggestions. Only give one or two at a time. For each suggestion, be very detailed--if you say "use the 'map' operation", then you should also say what the prompt and output schema should be. +Core Capabilities: +- DocETL enables users to create sophisticated data processing workflows combining LLMs with traditional data operations +- Each pipeline processes documents through a sequence of operations +- Operations can be LLM-based (map, reduce, resolve, filter) or utility-based (unnest, split, gather, sample) -You don't have the ability to write to the pipeline state. You can only give suggestions about what the user should do next. -Before jumping to a new operation, verify that the user is satisfied with the current operation's outputs. Ask them specific questions about the outputs related to the task, to get them to iterate on the current operation if needed. For example, if all the outputs look the same, maybe they want to be more specific in their operation prompt. The users will typically be biased towards accepting the current operation's outputs, so be very specific in your questions to get them to iterate on the current operation. You may have to propose iterations yourself. Always look for inadequacies in the outputs, and surface them to the user (with ) to see what they think. -Remember, always be specific! Never be vague or general. Never suggest new operations unless the user is completely satisfied with the current operation's outputs. -Your answers will be in markdown format. Use bold, italics, and lists to draw attention to important points. +Operation Details: +- Every LLM operation has: + - A prompt (jinja2 template) + - An output schema (JSON schema) +- Operation-specific templating: + - Map/Filter: Access current doc with '{{ input.keyname }}' + - Reduce: Loop through docs with '{% for doc in inputs %}...{% endfor %}' + - Resolve: Compare docs with '{{ input1 }}/{{ input2 }}' and canonicalize with 'inputs' + +Your Role: +- Help users optimize pipelines and overcome challenges +- Ask focused questions to understand goals (one at a time) +- Keep responses concise +- Provide 1-2 detailed suggestions at a time +- Cannot directly modify pipeline state - only provide guidance + +Best Practices: +- Verify satisfaction with current operation before suggesting new ones +- Ask specific questions about outputs to encourage iteration +- Look for and surface potential inadequacies in outputs +- Use markdown formatting (bold, italics, lists) for clarity. Only action items or suggestions should be bolded. +- Be specific, never vague or general +- Be concise, don't repeat yourself Here's their current pipeline state: ${pipelineState}`, diff --git a/website/src/contexts/PipelineContext.tsx b/website/src/contexts/PipelineContext.tsx index 06104c9d..35529005 100644 --- a/website/src/contexts/PipelineContext.tsx +++ b/website/src/contexts/PipelineContext.tsx @@ -84,6 +84,7 @@ const serializeState = async (state: PipelineState): Promise => { // Get important output samples let outputSample = ""; let currentOperationName = ""; + let schemaInfo = ""; if (state.output?.path) { try { @@ -106,6 +107,21 @@ const serializeState = async (state: PipelineState): Promise => { const importantColumns = operation?.output?.schema?.map((item) => item.key) || []; + // Generate schema information + if (outputs.length > 0) { + const firstRow = outputs[0]; + schemaInfo = Object.entries(firstRow) + .map(([key, value]) => { + const type = typeof value; + return `- ${key}: ${type}${ + importantColumns.includes(key) + ? " (output of current operation)" + : "" + }`; + }) + .join("\n"); + } + // Take up to 5 samples const samples = outputs .slice(0, 5) @@ -186,12 +202,16 @@ Input Dataset File: ${ Pipeline operations:${operationsDetails} -Bookmarks:${bookmarksDetails} +My notes:${bookmarks.length > 0 ? bookmarksDetails : "\nNo notes added yet"} ${ currentOperationName && outputSample ? ` Operation just executed: ${currentOperationName} -Sample output for current operation (the LLM-generated outputs for this operation are bolded; other keys are included but not bolded): + +Schema Information: +${schemaInfo} + +Sample output for current operation (the LLM-generated outputs for this operation are bolded; other keys from other operations or the original input file are included but not bolded): ${outputSample}` : "" }`; From 2b9aea83dcf88949871cc68688c712f18fa7757a Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Mon, 11 Nov 2024 13:14:19 -0800 Subject: [PATCH 4/6] fix: allow reduce_key types to be lists --- docetl/operations/reduce.py | 10 +++- tests/basic/test_basic_reduce_resolve.py | 64 ++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index e1ed4d32..2b6018c5 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -323,7 +323,15 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: else: # Group the input data by the reduce key(s) while maintaining original order def get_group_key(item): - return tuple(item[key] for key in reduce_keys) + key_values = [] + for key in reduce_keys: + value = item[key] + # Special handling for list-type values + if isinstance(value, list): + key_values.append(tuple(sorted(value))) # Convert list to sorted tuple + else: + key_values.append(value) + return tuple(key_values) grouped_data = {} for item in input_data: diff --git a/tests/basic/test_basic_reduce_resolve.py b/tests/basic/test_basic_reduce_resolve.py index 5326b5ab..26c26a8a 100644 --- a/tests/basic/test_basic_reduce_resolve.py +++ b/tests/basic/test_basic_reduce_resolve.py @@ -164,3 +164,67 @@ def test_reduce_operation_with_lineage( lineage = result[f"{reduce_config['name']}_lineage"] assert all(isinstance(item, dict) for item in lineage) assert all("name" in item and "email" in item for item in lineage) + + +def test_reduce_with_list_key(api_wrapper, default_model, max_threads): + """Test reduce operation with a list-type key""" + + # Test data with list-type classifications + input_data = [ + { + "content": "Document about AI and ML", + "classifications": ["AI", "ML"] + }, + { + "content": "Another ML and AI document", + "classifications": ["ML", "AI"] # Same classes but different order + }, + { + "content": "Document about AI only", + "classifications": ["AI"] + }, + { + "content": "Document about ML and Data", + "classifications": ["ML", "Data"] + } + ] + + # Configuration for reduce operation + config = { + "name": "test_reduce_list", + "type": "reduce", + "reduce_key": "classifications", + "prompt": """Combine the content of documents with the same classifications. + Input documents: {{ inputs }} + Please combine the content into a single summary.""", + "output": { + "schema": { + "combined_content": "string", + } + } + } + + # Create and execute reduce operation + operation = ReduceOperation(api_wrapper, config, default_model, max_threads) + results, _ = operation.execute(input_data) + + # Verify results + assert len(results) == 3 # Should have 3 groups: ["AI", "ML"], ["AI"], ["ML", "Data"] + + # Find the result with ["AI", "ML"] classifications + ai_ml_result = next(r for r in results if len(r["classifications"]) == 2 and "AI" in r["classifications"] and "ML" in r["classifications"]) + assert len(ai_ml_result["classifications"]) == 2 + assert set(ai_ml_result["classifications"]) == {"AI", "ML"} + + # Find the result with only ["AI"] classification + ai_result = next((r for r in results if r["classifications"] == ("AI",)), None) + if ai_result is None: + raise AssertionError("Could not find result with only ['AI'] classification") + assert ai_result["classifications"] == ("AI",) + + # Find the result with ["ML", "Data"] classifications + ml_data_result = next((r for r in results if set(r["classifications"]) == {"ML", "Data"}), None) + if ml_data_result is None: + raise AssertionError("Could not find result with ['ML', 'Data'] classification") + assert set(ml_data_result["classifications"]) == {"ML", "Data"} + From 7242da81f452cb4077acf494eed6543848b4ef18 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Mon, 11 Nov 2024 14:04:40 -0800 Subject: [PATCH 5/6] fix: allow user to pass in litellm completion kwargs --- docetl/operations/cluster.py | 1 + docetl/operations/equijoin.py | 4 ++++ docetl/operations/link_resolve.py | 1 + docetl/operations/map.py | 6 +++++- docetl/operations/reduce.py | 7 ++++++- docetl/operations/resolve.py | 5 +++++ docetl/operations/utils.py | 16 ++++++++++++---- tests/basic/test_basic_map.py | 32 +++++++++++++++++++++++++++++++ 8 files changed, 66 insertions(+), 6 deletions(-) diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 033e3bf4..19413a9a 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -219,6 +219,7 @@ def validation_fn(response: Dict[str, Any]): else None ), verbose=self.config.get("verbose", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) total_cost += response.total_cost if response.validated: diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index 884f0113..7ab99c0f 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -19,6 +19,8 @@ rich_as_completed, ) from docetl.utils import completion_cost +from pydantic import Field + # Global variables to store shared data _right_data = None @@ -66,6 +68,7 @@ class schema(BaseOperation.schema): limit_comparisons: Optional[int] = None blocking_keys: Optional[Dict[str, List[str]]] = None timeout: Optional[int] = None + litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) def compare_pair( self, @@ -101,6 +104,7 @@ def compare_pair( timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=self.config.get("bypass_cache", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) output = self.runner.api.parse_llm_response( response.response, {"is_match": "bool"} diff --git a/docetl/operations/link_resolve.py b/docetl/operations/link_resolve.py index 2669ac66..e6a95c84 100644 --- a/docetl/operations/link_resolve.py +++ b/docetl/operations/link_resolve.py @@ -175,6 +175,7 @@ def validation_fn(response: Dict[str, Any]): else None ), verbose=self.config.get("verbose", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) if response.validated: diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 91c2916f..7b697c08 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -49,6 +49,7 @@ class schema(BaseOperation.schema): batch_size: Optional[int] = None clustering_method: Optional[str] = None batch_prompt: Optional[str] = None + litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) @field_validator("drop_keys") def validate_drop_keys(cls, v): if isinstance(v, str): @@ -213,6 +214,7 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): verbose=self.config.get("verbose", False), bypass_cache=self.config.get("bypass_cache", False), initial_result=initial_result, + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) if llm_result.validated: @@ -249,7 +251,8 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: verbose=self.config.get("verbose", False), timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), - bypass_cache=self.config.get("bypass_cache", False) + bypass_cache=self.config.get("bypass_cache", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) total_cost += llm_result.total_cost @@ -460,6 +463,7 @@ def process_prompt(item, prompt_config): timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), bypass_cache=self.config.get("bypass_cache", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) output = self.runner.api.parse_llm_response( response.response, diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 2b6018c5..f3fa99f6 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -25,6 +25,7 @@ ) from docetl.operations.utils import rich_as_completed from docetl.utils import completion_cost +from pydantic import Field class ReduceOperation(BaseOperation): @@ -51,7 +52,8 @@ class schema(BaseOperation.schema): value_sampling: Optional[Dict[str, Any]] = None verbose: Optional[bool] = None timeout: Optional[int] = None - + litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) + def __init__(self, *args, **kwargs): """ Initialize the ReduceOperation. @@ -797,6 +799,7 @@ def _increment_fold( ), bypass_cache=self.config.get("bypass_cache", False), verbose=self.config.get("verbose", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) end_time = time.time() @@ -855,6 +858,7 @@ def _merge_results( ), bypass_cache=self.config.get("bypass_cache", False), verbose=self.config.get("verbose", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) end_time = time.time() @@ -964,6 +968,7 @@ def _batch_reduce( ), gleaning_config=self.config.get("gleaning", None), verbose=self.config.get("verbose", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) item_cost += response.total_cost diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 0f896261..88dfbbfd 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -16,6 +16,8 @@ from docetl.operations.utils import RichLoopBar, rich_as_completed from docetl.utils import completion_cost, extract_jinja_variables +from pydantic import Field + def find_cluster(item, cluster_map): while item != cluster_map[item]: @@ -42,6 +44,7 @@ class schema(BaseOperation.schema): limit_comparisons: Optional[int] = None optimize: Optional[bool] = None timeout: Optional[int] = None + litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) def compare_pair( self, @@ -84,6 +87,7 @@ def compare_pair( timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=self.config.get("bypass_cache", False), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) output = self.runner.api.parse_llm_response( response.response, @@ -541,6 +545,7 @@ def process_cluster(cluster): if self.config.get("validate", None) else None ), + litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) reduction_cost = reduction_response.total_cost diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py index a385c5ce..5dd60237 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils.py @@ -434,12 +434,13 @@ def call_llm_batch( timeout_seconds: int = 120, max_retries_per_timeout: int = 2, bypass_cache: bool = False, + litellm_completion_kwargs: Dict[str, Any] = {}, ) -> LLMResult: # Turn the output schema into a list of schemas output_schema = convert_dict_schema_to_list_schema(output_schema) # Invoke the LLM call - return self.call_llm(model, op_type,messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache) + return self.call_llm(model, op_type,messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache, litellm_completion_kwargs=litellm_completion_kwargs) def _cached_call_llm( @@ -456,6 +457,7 @@ def _cached_call_llm( verbose: bool = False, bypass_cache: bool = False, initial_result: Optional[Any] = None, + litellm_completion_kwargs: Dict[str, Any] = {}, ) -> LLMResult: """ Cached version of the call_llm function. @@ -489,7 +491,7 @@ def _cached_call_llm( else: if not initial_result: response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad + model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs ) total_cost += completion_cost(response) else: @@ -556,6 +558,7 @@ def _cached_call_llm( } ], tool_choice="required", + **litellm_completion_kwargs, ) total_cost += completion_cost(validator_response) @@ -583,7 +586,7 @@ def _cached_call_llm( # Call LLM again response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad + model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs ) parsed_output = self.parse_llm_response( response, output_schema, tools @@ -633,7 +636,7 @@ def _cached_call_llm( i += 1 response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad + model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs ) total_cost += completion_cost(response) @@ -662,6 +665,7 @@ def call_llm( verbose: bool = False, bypass_cache: bool = False, initial_result: Optional[Any] = None, + litellm_completion_kwargs: Dict[str, Any] = {}, ) -> LLMResult: """ Wrapper function that uses caching for LLM calls. @@ -706,6 +710,7 @@ def call_llm( verbose=verbose, bypass_cache=bypass_cache, initial_result=initial_result, + litellm_completion_kwargs=litellm_completion_kwargs, ) except RateLimitError: # TODO: this is a really hacky way to handle rate limits @@ -735,6 +740,7 @@ def _call_llm_with_cache( output_schema: Dict[str, str], tools: Optional[str] = None, scratchpad: Optional[str] = None, + litellm_completion_kwargs: Dict[str, Any] = {}, ) -> Any: """ Make an LLM call with caching. @@ -841,6 +847,7 @@ def _call_llm_with_cache( + messages, tools=tools, tool_choice=tool_choice, + **litellm_completion_kwargs, ) else: response = completion( @@ -852,6 +859,7 @@ def _call_llm_with_cache( }, ] + messages, + **litellm_completion_kwargs, ) diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py index 43fc3e21..a1919f6f 100644 --- a/tests/basic/test_basic_map.py +++ b/tests/basic/test_basic_map.py @@ -289,3 +289,35 @@ def test_map_operation_with_larger_batch(simple_map_config, map_sample_data_with assert all( any(vs in result["sentiment"] for vs in valid_sentiments) for result in results ) + +def test_map_operation_with_max_tokens(simple_map_config, map_sample_data, api_wrapper): + # Add litellm_completion_kwargs configuration with max_tokens + map_config_with_max_tokens = { + **simple_map_config, + "litellm_completion_kwargs": { + "max_tokens": 10 + }, + "bypass_cache": True + } + + operation = MapOperation(api_wrapper, map_config_with_max_tokens, "gpt-4o-mini", 4) + + # Execute the operation + results, cost = operation.execute(map_sample_data) + + # Assert that we have results for all input items + assert len(results) == len(map_sample_data) + + # Check that all results have a sentiment + assert all("sentiment" in result for result in results) + + # Verify that all sentiments are valid + valid_sentiments = ["positive", "negative", "neutral"] + assert all( + any(vs in result["sentiment"] for vs in valid_sentiments) for result in results + ) + + # Since we limited max_tokens to 10, each response should be relatively short + # The sentiment field should contain just the sentiment value without much extra text + assert all(len(result["sentiment"]) <= 20 for result in results) + From 21cece429c3786b22e7618efed7c04cfe85c4e9b Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Mon, 11 Nov 2024 14:09:06 -0800 Subject: [PATCH 6/6] fix: allow user to pass in litellm completion kwargs --- docs/concepts/operators.md | 9 ++++++++- docs/operators/cluster.md | 1 + docs/operators/map.md | 1 + docs/operators/parallel-map.md | 1 + docs/operators/reduce.md | 1 + docs/operators/resolve.md | 4 +++- 6 files changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/concepts/operators.md b/docs/concepts/operators.md index 617e4499..126af9d9 100644 --- a/docs/concepts/operators.md +++ b/docs/concepts/operators.md @@ -24,13 +24,20 @@ LLM-based operators have additional attributes: - `prompt`: A Jinja2 template that defines the instruction for the language model. - `output`: Specifies the schema for the output from the LLM call. - `model` (optional): Allows specifying a different model from the pipeline default. +- `litellm_completion_kwargs` (optional): Additional parameters to pass to LiteLLM completion calls. + +DocETL uses [LiteLLM](https://docs.litellm.ai) to execute all LLM calls, providing support for 100+ LLM providers including OpenAI, Anthropic, Azure, and more. You can pass any LiteLLM completion arguments using the `litellm_completion_kwargs` field. Example: ```yaml - name: extract_insights type: map - model: gpt-4o + model: gpt-4o-mini + litellm_completion_kwargs: + max_tokens: 500 # limit response length + temperature: 0.7 # control randomness + top_p: 0.9 # nucleus sampling parameter prompt: | Analyze the following user interaction log: {{ input.log }} diff --git a/docs/operators/cluster.md b/docs/operators/cluster.md index e7abb65a..4e4f0d3f 100644 --- a/docs/operators/cluster.md +++ b/docs/operators/cluster.md @@ -184,3 +184,4 @@ and a description, and groups them into a tree of categories. | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | | `sample` | Number of items to sample for this operation | None | +| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} | diff --git a/docs/operators/map.md b/docs/operators/map.md index bbc48727..10f02209 100644 --- a/docs/operators/map.md +++ b/docs/operators/map.md @@ -147,6 +147,7 @@ This example demonstrates how the Map operation can transform long, unstructured | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | | `timeout` | Timeout for each LLM call in seconds | 120 | +| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} | Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters. diff --git a/docs/operators/parallel-map.md b/docs/operators/parallel-map.md index 85ae7682..01afba9d 100644 --- a/docs/operators/parallel-map.md +++ b/docs/operators/parallel-map.md @@ -37,6 +37,7 @@ Each prompt configuration in the `prompts` list should contain: | `sample` | Number of samples to use for the operation | Processes all data | | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | +| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} | ??? question "Why use Parallel Map instead of multiple Map operations?" diff --git a/docs/operators/reduce.md b/docs/operators/reduce.md index 3c01fd82..220789b9 100644 --- a/docs/operators/reduce.md +++ b/docs/operators/reduce.md @@ -64,6 +64,7 @@ This Reduce operation processes customer feedback grouped by department: | `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false | | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | +| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} | ## Advanced Features diff --git a/docs/operators/resolve.md b/docs/operators/resolve.md index 72f01065..eeaf55fa 100644 --- a/docs/operators/resolve.md +++ b/docs/operators/resolve.md @@ -126,7 +126,9 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un | `limit_comparisons` | Maximum number of comparisons to perform | None | | `timeout` | Timeout for each LLM call in seconds | 120 | | `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | -| `sample` | Number of samples to use for the operation | None | +| `sample` | Number of samples to use for the operation | None | +| `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} | + ## Best Practices 1. **Anticipate Resolve Needs**: If you anticipate needing a Resolve operation and want to control the prompts, create it in your pipeline and let the optimizer find the appropriate blocking rules and thresholds.