From 285e63f674f5049be6ec60d7e2f204baec427264 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 17 Sep 2024 15:06:28 +0200 Subject: [PATCH] wip --- .github/actions/spelling/allow.txt | 2 + .../README.md | 18 ++-- .../app/README.md | 4 +- .../app/chain.py | 21 ++-- .../app/eval/data/chats.yaml | 2 +- .../app/eval/utils.py | 31 +++--- .../app/patterns/custom_rag_qa/chain.py | 23 ++-- .../app/patterns/custom_rag_qa/templates.py | 31 ++++-- .../patterns/custom_rag_qa/vector_store.py | 4 +- .../patterns/langgraph_dummy_agent/chain.py | 23 ++-- .../app/server.py | 53 +++++----- .../app/utils/input_types.py | 9 +- .../app/utils/output_types.py | 18 +--- .../app/utils/tracing.py | 38 +++---- .../deployment/README.md | 2 +- .../deployment/terraform/artifact_registry.tf | 2 +- .../terraform/dev/service_accounts.tf | 2 +- .../deployment/terraform/service_accounts.tf | 2 +- .../notebooks/getting_started.ipynb | 56 ++++++---- .../poetry.lock | 2 +- .../streamlit/side_bar.py | 82 ++++++++------ .../streamlit/streamlit_app.py | 100 ++++++++---------- .../streamlit/style/app_markdown.py | 2 +- .../streamlit/utils/local_chat_history.py | 30 +++--- .../streamlit/utils/message_editing.py | 34 ++++-- .../streamlit/utils/multimodal_utils.py | 78 ++++++++------ .../streamlit/utils/stream_handler.py | 72 ++++++------- .../streamlit/utils/title_summary.py | 2 +- .../streamlit/utils/utils.py | 4 +- .../patterns/test_langgraph_dummy_agent.py | 21 ++-- .../tests/integration/patterns/test_rag_qa.py | 23 ++-- .../tests/integration/test_chain.py | 23 ++-- .../tests/integration/test_server_e2e.py | 42 +++++--- .../tests/load_test/load_test.py | 12 +-- .../tests/unit/test_server.py | 22 ++-- .../unit/test_utils/test_tracing_exporter.py | 47 ++++---- 36 files changed, 517 insertions(+), 420 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b6b729c1f75..016f74fcf7a 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -118,6 +118,7 @@ Khanh Knopf Kubeflow Kudrow +Langraph LCEL LLMs LOOKBACK @@ -286,6 +287,7 @@ codebase codebases codelab codelabs +codespell colab coldline coloraxis diff --git a/gemini/sample-apps/conversational-genai-app-template/README.md b/gemini/sample-apps/conversational-genai-app-template/README.md index d9a653697c6..cbb60320418 100644 --- a/gemini/sample-apps/conversational-genai-app-template/README.md +++ b/gemini/sample-apps/conversational-genai-app-template/README.md @@ -1,15 +1,15 @@ -# πŸš€ Conversational GenAI App Template! πŸš€ +# πŸš€ Conversational Generative AI App Template! πŸš€ >**Focus on Innovation, not Infrastructure** This folder is a **starter-pack** to building a Generative AI application on Google Cloud Platform (GCP). -It is meant to be a template repository to build your own GenAI application. +It is meant to be a template repository to build your own Generative AI application. We provide a comprehensive set of resources to guide you through the entire development process, from prototype to production. ## High-Level Architecture -This template covers all aspects of GenAI app development, from prototyping and evaluation to deployment and monitoring. +This template covers all aspects of Generative AI app development, from prototyping and evaluation to deployment and monitoring. 
![High Level Architecture](images/high_level_architecture.png "Architecture") @@ -39,7 +39,7 @@ This template covers all aspects of GenAI app development, from prototyping and | Description | Visualization | |-------------|---------------| -| The repository showcases how to evaluate GenAI applications using tools like Vertex AI rapid eval SDK and Vertex AI Experiments. | ![Vertex AI Rapid Eval](images/vertex_ai_rapid_eval.png) | +| The repository showcases how to evaluate Generative AI applications using tools like Vertex AI rapid eval SDK and Vertex AI Experiments. | ![Vertex AI Rapid Eval](images/vertex_ai_rapid_eval.png) | @@ -60,7 +60,7 @@ This template covers all aspects of GenAI app development, from prototyping and | Description | Visualization | |-------------|---------------| -| Monitor your GenAI Application's performance. We provide a Looker Studio [dashboard](https://lookerstudio.google.com/u/0/reporting/fa742264-4b4b-4c56-81e6-a667dd0f853f) to monitor application conversation statistics and user feedback. | ![Dashboard1](images/dashboard_1.png) | +| Monitor your Generative AI Application's performance. We provide a Looker Studio [dashboard](https://lookerstudio.google.com/u/0/reporting/fa742264-4b4b-4c56-81e6-a667dd0f853f) to monitor application conversation statistics and user feedback. | ![Dashboard1](images/dashboard_1.png) | | We can also drill down to individual conversations and view the messages exchanged | ![Dashboard2](images/dashboard_2.png) | @@ -79,7 +79,7 @@ This template covers all aspects of GenAI app development, from prototyping and | Description | Visualization | |-------------|---------------| -| Experiment with your GenAI Application in a feature-rich playground, including chat curation, user feedback collection, multimodal input, and more! | ![Streamlit View](images/streamlit_view.png) | +| Experiment with your Generative AI Application in a feature-rich playground, including chat curation, user feedback collection, multimodal input, and more! | ![Streamlit View](images/streamlit_view.png) | @@ -97,7 +97,7 @@ This template covers all aspects of GenAI app development, from prototyping and ```bash gsutil -m cp -r gs://genai-starter-pack-templates/genai-starter-pack-template . ``` -Use the downloaded folder as a starting point for your own GenAI application. +Use the downloaded folder as a starting point for your own Generative AI application. ### Installation @@ -131,11 +131,11 @@ For full command options and usage, refer to the [Makefile](Makefile). ## Usage -1. **Prototype Your Chain:** Build your GenAI Application using different methodologies and frameworks. Use Vertex AI Evaluation for assessing the performances of your application and relative chain of steps. **See [`notebooks/getting_started.ipynb`](notebooks/getting_started.ipynb) for a tutorial to get started building and evaluating your chain.** +1. **Prototype Your Chain:** Build your Generative AI Application using different methodologies and frameworks. Use Vertex AI Evaluation for assessing the performances of your application and relative chain of steps. **See [`notebooks/getting_started.ipynb`](notebooks/getting_started.ipynb) for a tutorial to get started building and evaluating your chain.** 2. **Integrate into the App:** Import your chain into the app. Edit `app/chain.py` file to add your chain. 3. **Playground Testing:** Explore your chain's functionality using the Streamlit playground. 
Take advantage of the comprehensive playground features, such as chat history management, user feedback mechanisms, support for various input types, and additional capabilities. You can run the playground locally with `make playground` command. 4. **Deploy with CI/CD:** Configure and trigger the CI/CD pipelines. Edit tests if needed. See the [deployment section](#deployment) below for more details. -5. **Monitor in Production:** Track performance and gather insights using Cloud Logging, Tracing, and the Looker Studio dashboard. Use the gathered data to iterate on your GenAI application. +5. **Monitor in Production:** Track performance and gather insights using Cloud Logging, Tracing, and the Looker Studio dashboard. Use the gathered data to iterate on your Generative AI application. ## Deployment diff --git a/gemini/sample-apps/conversational-genai-app-template/app/README.md b/gemini/sample-apps/conversational-genai-app-template/app/README.md index c006a85f36e..bafc371c3b1 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/README.md +++ b/gemini/sample-apps/conversational-genai-app-template/app/README.md @@ -14,7 +14,7 @@ This folder implements a chatbot application using FastAPI, and Google Cloud ser β”œβ”€β”€ utils/ # Utility functions and classes └── eval/ # Evaluation tools and data ``` -## GenAI Application Patterns +## Generative AI Application Patterns ### 1. Default Chain @@ -45,7 +45,7 @@ All chains have the same interface, allowing for seamless swapping without chang ### Trace and Log Capture -This application utilizes [OpenTelemetry](https://opentelemetry.io/) and [OpenLLMetry](https://github.com/traceloop/openllmetry) for comprehensive observability, emitting events to Google Cloud Trace and Google Cloud Logging. Every interaction with Langchain and VertexAI is instrumented (see [`server.py`](server.py)), enabling detailed tracing of request flows throughout the application. +This application utilizes [OpenTelemetry](https://opentelemetry.io/) and [OpenLLMetry](https://github.com/traceloop/openllmetry) for comprehensive observability, emitting events to Google Cloud Trace and Google Cloud Logging. Every interaction with LangChain and VertexAI is instrumented (see [`server.py`](server.py)), enabling detailed tracing of request flows throughout the application. Leveraging the [CloudTraceSpanExporter](https://cloud.google.com/python/docs/reference/spanner/latest/opentelemetry-tracing), the application captures and exports tracing data. To address the limitations of Cloud Trace ([256-byte attribute value limit](https://cloud.google.com/trace/docs/quotas#limits_on_spans)) and [Cloud Logging](https://cloud.google.com/logging/quotas) ([256KB log entry size](https://cloud.google.com/logging/quotas)), a custom extension of the CloudTraceSpanExporter is implemented in [`app/utils/tracing.py`](app/utils/tracing.py). 
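To make the exporter setup described above concrete, here is a minimal sketch of registering the extended exporter with a standard OpenTelemetry tracer provider. It is illustrative only and not copied from `app/server.py`: the provider wiring is the generic OpenTelemetry pattern, and the `bucket_name` value is a placeholder (the exporter's own constructor arguments come from `app/utils/tracing.py` in this patch).

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

from app.utils.tracing import CloudTraceLoggingSpanExporter

# Custom exporter from this patch: spans go to Cloud Trace, full span payloads are
# logged to Cloud Logging, and oversized attributes are offloaded to a GCS bucket.
exporter = CloudTraceLoggingSpanExporter(
    bucket_name="my-project-logs-data",  # placeholder; defaults to "<project_id>-logs-data"
    debug=False,
)

# Generic OpenTelemetry wiring. server.py imports this same exporter for use with the
# Traceloop SDK; this framework-agnostic registration is an illustrative equivalent.
provider = TracerProvider()
provider.add_span_processor(BatchSpanProcessor(exporter))
trace.set_tracer_provider(provider)
```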
diff --git a/gemini/sample-apps/conversational-genai-app-template/app/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/chain.py index c126d460413..497ec5c88d5 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/chain.py @@ -15,25 +15,22 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory -safety_settings = { - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, -} llm = ChatVertexAI( - model_name="gemini-1.5-flash-001", temperature=0, max_output_tokens=1024, - safety_settings=safety_settings + model_name="gemini-1.5-flash-001", + temperature=0, + max_output_tokens=1024, ) template = ChatPromptTemplate.from_messages( [ - ("system", """You are a conversational bot that produce recipes for users based - on a question."""), - MessagesPlaceholder(variable_name="messages") + ( + "system", + """You are a conversational bot that produce recipes for users based + on a question.""", + ), + MessagesPlaceholder(variable_name="messages"), ] ) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/eval/data/chats.yaml b/gemini/sample-apps/conversational-genai-app-template/app/eval/data/chats.yaml index bb950ac2bf1..3375e6c22d4 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/eval/data/chats.yaml +++ b/gemini/sample-apps/conversational-genai-app-template/app/eval/data/chats.yaml @@ -39,4 +39,4 @@ - type: human content: Those all sound great! I like the Burnt aubergine veggie chilli - type: ai - content: That's a great choice! I hope you enjoy it. \ No newline at end of file + content: That's a great choice! I hope you enjoy it. \ No newline at end of file diff --git a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py index 9892381ef63..df070cc65ec 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/eval/utils.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from concurrent.futures import ThreadPoolExecutor +from functools import partial import glob import logging import os -from concurrent.futures import ThreadPoolExecutor -from functools import partial from typing import Any, Callable, Dict, Iterator, List import nest_asyncio import pandas as pd -import yaml from tqdm import tqdm +import yaml nest_asyncio.apply() + def load_chats(path: str) -> List[Dict[str, Any]]: """ Loads a list of chats from a directory or file. @@ -44,6 +45,7 @@ def load_chats(path: str) -> List[Dict[str, Any]]: chats = chats + chats_in_file return chats + def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]: """Creates an iterable with tuples paired together e.g s -> (s0, s1), (s2, s3), (s4, s5), ... 
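The body of `pairwise` falls outside the hunks shown here; for reference, a helper matching that docstring is typically the standard "zip an iterator with itself" idiom. The snippet below is an assumed equivalent for illustration, not the patch's exact code.

```python
from typing import Any, Iterator, List


def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]:
    """s -> (s0, s1), (s2, s3), (s4, s5), ..."""
    it = iter(iterable)
    return zip(it, it)  # consecutive, non-overlapping pairs


# Example: pairing the alternating human/AI turns loaded from chats.yaml.
turns = ["Hi, I need a recipe", "How about pasta?", "Sounds good", "Great, here it is"]
print(list(pairwise(turns)))
# [('Hi, I need a recipe', 'How about pasta?'), ('Sounds good', 'Great, here it is')]
```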
@@ -81,11 +83,9 @@ def generate_multiturn_history(df: pd.DataFrame) -> pd.DataFrame: message = { "human_message": human_message, "ai_message": ai_message, - "conversation_history": conversation_history + "conversation_history": conversation_history, } - conversation_history = conversation_history + [ - human_message, ai_message - ] + conversation_history = conversation_history + [human_message, ai_message] processed_messages.append(message) return pd.DataFrame(processed_messages) @@ -103,7 +103,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str Args: row (tuple[int, Dict[str, Any]]): A tuple containing the index and a dictionary with message data, including: - - "conversation_history" (List[str]): Optional. List of previous + - "conversation_history" (List[str]): Optional. List of previous messages in the conversation. - "human_message" (str): The current human message. @@ -118,7 +118,9 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str - "response_obj" (Any): The usage metadata of the response from the callable. """ index, message = row - messages = message["conversation_history"] if "conversation_history" in message else [] + messages = ( + message["conversation_history"] if "conversation_history" in message else [] + ) messages.append(message["human_message"]) input_callable = {"messages": messages} response = callable.invoke(input_callable) @@ -130,7 +132,7 @@ def generate_message(row: tuple[int, Dict[str, Any]], callable: Any) -> Dict[str def batch_generate_messages( messages: pd.DataFrame, callable: Callable[[List[Dict[str, Any]]], Dict[str, Any]], - max_workers: int = 4 + max_workers: int = 4, ) -> pd.DataFrame: """Generates AI responses to user messages using a provided callable. @@ -152,8 +154,8 @@ def batch_generate_messages( ] ``` - callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object - (e.g., Langchain Chain) used + callable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Callable object + (e.g., LangChain Chain) used for response generation. It should accept a list of message dictionaries (as described above) and return a dictionary with the following structure: @@ -202,6 +204,7 @@ def batch_generate_messages( predicted_messages.append(message) return pd.DataFrame(predicted_messages) + def save_df_to_csv(df: pd.DataFrame, dir_path: str, filename: str) -> None: """Saves a pandas DataFrame to directory as a CSV file. 
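As a usage sketch for `batch_generate_messages` above: the column layout (`human_message`, `conversation_history`) mirrors what `generate_message` reads, and the message dicts follow the `chats.yaml` format. The single-row DataFrame and the choice of the default `chain` are illustrative assumptions.

```python
import pandas as pd

from app.chain import chain  # any object exposing .invoke(), e.g. the default LCEL chain
from app.eval.utils import batch_generate_messages

eval_df = pd.DataFrame(
    [
        {
            "human_message": {"type": "human", "content": "Suggest a quick pasta recipe."},
            "conversation_history": [],
        }
    ]
)

# Runs chain.invoke concurrently via ThreadPoolExecutor and returns a DataFrame with the
# generated AI message and response metadata appended to each row.
results = batch_generate_messages(messages=eval_df, callable=chain, max_workers=4)
```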
@@ -233,7 +236,9 @@ def prepare_metrics(metrics: List[str]) -> List[Any]: *module_path, metric_name = metric.removeprefix("custom:").split(".") metrics_evaluation.append( __import__(".".join(module_path), fromlist=[metric_name]).__dict__[ - metric_name]) + metric_name + ] + ) else: metrics_evaluation.append(metric) return metrics_evaluation diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py index 7ef2475287e..4ee8d39ea50 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/chain.py @@ -16,14 +16,13 @@ import logging from typing import Any, Dict, Iterator -import google -import vertexai -from langchain_google_community.vertex_rank import VertexAIRank -from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings - from app.patterns.custom_rag_qa.templates import query_rewrite_template, rag_template from app.patterns.custom_rag_qa.vector_store import get_vector_store from app.utils.output_types import OnChatModelStreamEvent, OnToolEndEvent, custom_chain +import google +from langchain_google_community.vertex_rank import VertexAIRank +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +import vertexai # Configuration EMBEDDING_MODEL = "text-embedding-004" @@ -38,7 +37,7 @@ embedding = VertexAIEmbeddings(model_name=EMBEDDING_MODEL) vector_store = get_vector_store(embedding=embedding) retriever = vector_store.as_retriever(search_kwargs={"k": 20}) -reranker = VertexAIRank( +compressor = VertexAIRank( project_id=project_id, location_id="global", ranking_config="default_ranking_config", @@ -52,9 +51,11 @@ @custom_chain -def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]: +def chain( + input: Dict[str, Any], **kwargs: Any +) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]: """ - Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events + Implements a RAG QA chain. Decorated with `custom_chain` to offer LangChain compatible astream_events and invoke interface and OpenTelemetry tracing. 
""" # Generate optimized query @@ -62,13 +63,13 @@ def chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnToolEndEvent | OnC # Retrieve and rank documents retrieved_docs = retriever.invoke(query) - ranked_docs = reranker.compress_documents(documents=retrieved_docs, query=query) + ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query) # Yield tool results metadata yield OnToolEndEvent(data={"input": {"query": query}, "output": ranked_docs}) # Stream LLM response for chunk in response_chain.stream( - input={"messages": input["messages"], "relevant_documents": ranked_docs} + input={"messages": input["messages"], "relevant_documents": ranked_docs} ): - yield OnChatModelStreamEvent(data={"chunk": chunk}) \ No newline at end of file + yield OnChatModelStreamEvent(data={"chunk": chunk}) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py index bd2a7482c65..d91b9fb76d1 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/templates.py @@ -14,14 +14,22 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -query_rewrite_template = ChatPromptTemplate.from_messages([ - ("system", "Rewrite a query to a semantic search engine using the current conversation. " - "Provide only the rewritten query as output."), - MessagesPlaceholder(variable_name="messages") -]) +query_rewrite_template = ChatPromptTemplate.from_messages( + [ + ( + "system", + "Rewrite a query to a semantic search engine using the current conversation. " + "Provide only the rewritten query as output.", + ), + MessagesPlaceholder(variable_name="messages"), + ] +) -rag_template = ChatPromptTemplate.from_messages([ - ("system", """You are an AI assistant for question-answering tasks. Follow these guidelines: +rag_template = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are an AI assistant for question-answering tasks. Follow these guidelines: 1. Use only the provided context to answer the question. 2. Give clear, accurate responses based on the information available. 3. If the context is insufficient, state: "I don't have enough information to answer this question." 
@@ -39,6 +47,9 @@ {{ doc.page_content | safe }} {% endfor %} -"""), - MessagesPlaceholder(variable_name="messages") -], template_format="jinja2") \ No newline at end of file +""", + ), + MessagesPlaceholder(variable_name="messages"), + ], + template_format="jinja2", +) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py index 1738ed8cba7..017d1383a10 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/custom_rag_qa/vector_store.py @@ -40,8 +40,8 @@ def load_and_split_documents(url: str) -> List[Document]: def get_vector_store( - embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL - ) -> SKLearnVectorStore: + embedding: Embeddings, persist_path: str = PERSIST_PATH, url: str = URL +) -> SKLearnVectorStore: """Get or create a vector store.""" vector_store = SKLearnVectorStore(embedding=embedding, persist_path=persist_path) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py index 507b65f4c00..fced6df507a 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/patterns/langgraph_dummy_agent/chain.py @@ -30,27 +30,30 @@ def search(query: str) -> str: return "It's 60 degrees and foggy." return "It's 90 degrees and sunny." + tools = [search] # 2. Set up the language model llm = ChatVertexAI( - model="gemini-1.5-pro-001", - temperature=0, - max_tokens=1024, - streaming=True + model="gemini-1.5-pro-001", temperature=0, max_tokens=1024, streaming=True ).bind_tools(tools) + # 3. Define workflow components def should_continue(state: MessagesState) -> str: """Determines whether to use tools or end the conversation.""" - last_message = state['messages'][-1] - return "tools" if last_message.tool_calls else END # type: ignore[union-attr] + last_message = state["messages"][-1] + return "tools" if last_message.tool_calls else END # type: ignore[union-attr] -async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, BaseMessage]: + +async def call_model( + state: MessagesState, config: RunnableConfig +) -> Dict[str, BaseMessage]: """Calls the language model and returns the response.""" - response = llm.invoke(state['messages'], config) + response = llm.invoke(state["messages"], config) return {"messages": response} + # 4. Create the workflow graph workflow = StateGraph(MessagesState) workflow.add_node("agent", call_model) @@ -59,7 +62,7 @@ async def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, # 5. Define graph edges workflow.add_conditional_edges("agent", should_continue) -workflow.add_edge("tools", 'agent') +workflow.add_edge("tools", "agent") # 6. 
Compile the workflow -chain = workflow.compile() \ No newline at end of file +chain = workflow.compile() diff --git a/gemini/sample-apps/conversational-genai-app-template/app/server.py b/gemini/sample-apps/conversational-genai-app-template/app/server.py index 63ce525fda1..19e1b004580 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/server.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/server.py @@ -15,22 +15,20 @@ import json import logging import os -import uuid from typing import AsyncGenerator +import uuid +# ruff: noqa: I001 +## Import the chain to be used +from app.chain import chain +from app.utils.input_types import Feedback, Input, InputChat, default_serialization +from app.utils.output_types import EndEvent, Event +from app.utils.tracing import CloudTraceLoggingSpanExporter from fastapi import FastAPI from fastapi.responses import RedirectResponse, StreamingResponse from google.cloud import logging as gcp_logging from traceloop.sdk import Instruments, Traceloop -from app.utils.input_types import Feedback, Input, InputChat, default_serialization -from app.utils.output_types import EndEvent, Event -from app.utils.tracing import CloudTraceLoggingSpanExporter - -# ruff: noqa: I001 -## Import the chain to be used -from app.chain import chain - # Or choose one of the following pattern chains to test by uncommenting it: # Custom RAG QA @@ -40,8 +38,13 @@ # from app.patterns.langgraph_dummy_agent.chain import chain # The events that are supported by the UI Frontend -SUPPORTED_EVENTS = ["on_tool_start", "on_tool_end", "on_retriever_start", - "on_retriever_end", "on_chat_model_stream"] +SUPPORTED_EVENTS = [ + "on_tool_start", + "on_tool_end", + "on_retriever_start", + "on_retriever_end", + "on_chat_model_stream", +] # Initialize FastAPI app and logging app = FastAPI() @@ -63,19 +66,20 @@ async def stream_event_response(input_chat: InputChat) -> AsyncGenerator[str, None]: run_id = uuid.uuid4() input_dict = input_chat.model_dump() - Traceloop.set_association_properties({ - "log_type": "tracing", - "run_id": str(run_id), - "user_id": input_dict["user_id"], - "session_id": input_dict["session_id"], - "commit_sha": os.environ.get("COMMIT_SHA", "None") - }) + Traceloop.set_association_properties( + { + "log_type": "tracing", + "run_id": str(run_id), + "user_id": input_dict["user_id"], + "session_id": input_dict["session_id"], + "commit_sha": os.environ.get("COMMIT_SHA", "None"), + } + ) yield json.dumps( - Event( - event="metadata", - data={"run_id": str(run_id)} - ), default=default_serialization) + "\n" + Event(event="metadata", data={"run_id": str(run_id)}), + default=default_serialization, + ) + "\n" async for data in chain.astream_events(input_dict, version="v2"): if data["event"] in SUPPORTED_EVENTS: @@ -97,8 +101,9 @@ async def collect_feedback(feedback_dict: Feedback) -> None: @app.post("/stream_events") async def stream_chat_events(request: Input) -> StreamingResponse: - return StreamingResponse(stream_event_response(input_chat=request.input), - media_type="text/event-stream") + return StreamingResponse( + stream_event_response(input_chat=request.input), media_type="text/event-stream" + ) # Main execution diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py index 0434fb09e92..b314474a6b3 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/app/utils/input_types.py @@ -20,9 +20,9 @@ class InputChat(BaseModel): """Represents the input for a chat session.""" + messages: List[Union[HumanMessage, AIMessage]] = Field( - ..., - description="The chat messages representing the current conversation." + ..., description="The chat messages representing the current conversation." ) user_id: str = "" session_id: str = "" @@ -30,16 +30,17 @@ class InputChat(BaseModel): class Input(BaseModel): """Wrapper class for InputChat.""" + input: InputChat class Feedback(BaseModel): """Represents feedback for a conversation.""" + score: Union[int, float] text: Optional[str] = None run_id: str - log_type: Literal['feedback'] = 'feedback' - + log_type: Literal["feedback"] = "feedback" def default_serialization(obj: Any) -> Any: diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py index 21b26914007..1d3bf6066a7 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/output_types.py @@ -12,18 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import uuid from functools import wraps from types import GeneratorType -from typing import ( - Any, - AsyncGenerator, - Callable, - Dict, - List, - Literal, - Union, -) +from typing import Any, AsyncGenerator, Callable, Dict, List, Literal, Union +import uuid from langchain_core.documents import Document from langchain_core.messages import AIMessage, AIMessageChunk @@ -112,8 +104,7 @@ def invoke(self, *args: Any, **kwargs: Any) -> AIMessage: elif isinstance(event, OnToolEndEvent): tool_calls.append(event.data.model_dump()) return AIMessage( - content=response_content, - additional_kwargs={"tool_calls_data": tool_calls} + content=response_content, additional_kwargs={"tool_calls_data": tool_calls} ) def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -125,8 +116,9 @@ def custom_chain(func: Callable) -> CustomChain: """ Decorator function that wraps a callable in a CustomChain instance. """ + @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) - return CustomChain(wrapper) \ No newline at end of file + return CustomChain(wrapper) diff --git a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py index 63fd4293fb8..c770ae836f6 100644 --- a/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py +++ b/gemini/sample-apps/conversational-genai-app-template/app/utils/tracing.py @@ -27,17 +27,18 @@ class CloudTraceLoggingSpanExporter(CloudTraceSpanExporter): """ An extended version of CloudTraceSpanExporter that logs span data to Google Cloud Logging and handles large attribute values by storing them in Google Cloud Storage. - + This class helps bypass the 256 character limit of Cloud Trace for attribute values by leveraging Cloud Logging (which has a 256KB limit) and Cloud Storage for larger payloads. """ + def __init__( self, logging_client: Optional[gcp_logging.Client] = None, storage_client: Optional[storage.Client] = None, bucket_name: Optional[str] = None, debug: bool = False, - **kwargs: Any + **kwargs: Any, ) -> None: """ Initialize the exporter with Google Cloud clients and configuration. 
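Stepping back to the `custom_chain` decorator defined in `app/utils/output_types.py` a few hunks above, a hypothetical toy chain shows the intended usage. The echo behaviour and the expected outputs noted in comments are assumptions for illustration; the event shapes follow the `custom_rag_qa` chain in this patch.

```python
from typing import Any, Dict, Iterator

from langchain_core.messages import AIMessageChunk

from app.utils.output_types import OnChatModelStreamEvent, custom_chain


@custom_chain
def echo_chain(input: Dict[str, Any], **kwargs: Any) -> Iterator[OnChatModelStreamEvent]:
    """Toy chain: streams the last user message back one word at a time."""
    last = input["messages"][-1]
    text = last if isinstance(last, str) else last.content
    for word in str(text).split():
        yield OnChatModelStreamEvent(data={"chunk": AIMessageChunk(content=word + " ")})


# The decorator returns a CustomChain, so the function now exposes a LangChain-style
# surface: echo_chain.invoke({"messages": ["hello world"]}) collects the streamed chunks
# into a single AIMessage, and echo_chain.astream_events(...) yields v2-style events.
```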
@@ -50,14 +51,15 @@ def __init__( """ super().__init__(**kwargs) self.debug = debug - self.logging_client = logging_client or gcp_logging.Client(project=self.project_id) + self.logging_client = logging_client or gcp_logging.Client( + project=self.project_id + ) self.logger = self.logging_client.logger(__name__) self.storage_client = storage_client or storage.Client(project=self.project_id) self.bucket_name = bucket_name or f"{self.project_id}-logs-data" self._ensure_bucket_exists() self.bucket = self.storage_client.bucket(self.bucket_name) - def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: """ Export the spans to Google Cloud Logging and Cloud Trace. @@ -67,13 +69,13 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: """ for span in spans: span_context = span.get_span_context() - trace_id = format(span_context.trace_id, 'x') - span_id = format(span_context.span_id, 'x') + trace_id = format(span_context.trace_id, "x") + span_id = format(span_context.span_id, "x") span_dict = json.loads(span.to_json()) span_dict["trace"] = f"projects/{self.project_id}/traces/{trace_id}" span_dict["span_id"] = span_id - + span_dict = self._process_large_attributes( span_dict=span_dict, span_id=span_id ) @@ -93,7 +95,6 @@ def _ensure_bucket_exists(self) -> None: logging.info(f"Bucket {self.bucket_name} not detected. Creating it now.") self.storage_client.create_bucket(self.bucket_name) - def store_in_gcs(self, content: str, span_id: str) -> str: """ Initiate storing large content in Google Cloud Storage/ @@ -105,15 +106,14 @@ def store_in_gcs(self, content: str, span_id: str) -> str: blob_name = f"spans/{span_id}.json" blob = self.bucket.blob(blob_name) - blob.upload_from_string(content, 'application/json') + blob.upload_from_string(content, "application/json") return f"gs://{self.bucket_name}/{blob_name}" - def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: """ - Process large attribute values by storing them in GCS if they exceed the size + Process large attribute values by storing them in GCS if they exceed the size limit of Google Cloud Logging. 
- + :param span_dict: The span data dictionary :param trace_id: The trace ID :param span_id: The span ID @@ -123,14 +123,16 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: if len(json.dumps(attributes).encode()) > 255 * 1024: # 250 KB # Separate large payload from other attributes attributes_payload = { - k: v for k, v in attributes.items() + k: v + for k, v in attributes.items() if "traceloop.association.properties" not in k } attributes_retain = { - k: v for k, v in attributes.items() + k: v + for k, v in attributes.items() if "traceloop.association.properties" in k } - + # Store large payload in GCS gcs_uri = self.store_in_gcs(json.dumps(attributes_payload), span_id) attributes_retain["uri_payload"] = gcs_uri @@ -138,11 +140,11 @@ def _process_large_attributes(self, span_dict: dict, span_id: str) -> dict: f"https://storage.mtls.cloud.google.com/" f"{self.bucket_name}/spans/{span_id}.json" ) - + span_dict["attributes"] = attributes_retain logging.info( "Length of payload span above 250 KB, storing attributes in GCS " "to avoid large log entry errors" ) - - return span_dict \ No newline at end of file + + return span_dict diff --git a/gemini/sample-apps/conversational-genai-app-template/deployment/README.md b/gemini/sample-apps/conversational-genai-app-template/deployment/README.md index ac00e931db6..ea24a1ee242 100644 --- a/gemini/sample-apps/conversational-genai-app-template/deployment/README.md +++ b/gemini/sample-apps/conversational-genai-app-template/deployment/README.md @@ -1,5 +1,5 @@ ## Deployment README.md -This folder contains the infrastructure-as-code and CI/CD pipeline configurations for deploying a conversational GenAI application on Google Cloud. +This folder contains the infrastructure-as-code and CI/CD pipeline configurations for deploying a conversational Generative AI application on Google Cloud. The application leverages [**Terraform**](http://terraform.io) to define and provision the underlying infrastructure, while [**Cloud Build**](https://cloud.google.com/build/) orchestrates the continuous integration and continuous deployment (CI/CD) pipeline. 
diff --git a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/artifact_registry.tf b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/artifact_registry.tf index fed1c2c7306..028aa4f05b3 100644 --- a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/artifact_registry.tf +++ b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/artifact_registry.tf @@ -1,7 +1,7 @@ resource "google_artifact_registry_repository" "my-repo" { location = "us-central1" repository_id = var.artifact_registry_repo_name - description = "Repo for GenAI applications" + description = "Repo for Generative AI applications" format = "DOCKER" project = var.cicd_runner_project_id depends_on = [resource.google_project_service.cicd_services, resource.google_project_service.shared_services] diff --git a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/dev/service_accounts.tf b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/dev/service_accounts.tf index cbf04d39f44..935bd815f50 100644 --- a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/dev/service_accounts.tf +++ b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/dev/service_accounts.tf @@ -1,5 +1,5 @@ resource "google_service_account" "cloud_run_app_sa" { account_id = var.cloud_run_app_sa_name - display_name = "Cloud Run GenAI app SA" + display_name = "Cloud Run Generative AI app SA" project = var.dev_project_id } diff --git a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/service_accounts.tf b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/service_accounts.tf index 1e2b8fa8753..b8eec7d7474 100644 --- a/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/service_accounts.tf +++ b/gemini/sample-apps/conversational-genai-app-template/deployment/terraform/service_accounts.tf @@ -9,7 +9,7 @@ resource "google_service_account" "cloud_run_app_sa" { for_each = local.project_ids account_id = var.cloud_run_app_sa_name - display_name = "Cloud Run GenAI app SA" + display_name = "Cloud Run Generative AI app SA" project = each.value depends_on = [resource.google_project_service.cicd_services, resource.google_project_service.shared_services] } diff --git a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb index a1639a21469..a14f0e332e6 100644 --- a/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb +++ b/gemini/sample-apps/conversational-genai-app-template/notebooks/getting_started.ipynb @@ -56,7 +56,7 @@ "It covers:\n", "\n", "1. Creating chains using different methods:\n", - " - Langchain LCEL (LangChain Expression Language)\n", + " - LangChain LCEL (LangChain Expression Language)\n", " - LangGraph\n", " - Custom Python code\n", "2. 
Evaluating these chains\n", @@ -236,6 +236,7 @@ "outputs": [], "source": [ "import sys\n", + "\n", "sys.path.append(\"../\")" ] }, @@ -276,7 +277,7 @@ "### Input Interface\n", "\n", "The chain must provide an `astream_events` method that accepts a dictionary with a \"messages\" key.\n", - "The \"messages\" value should be a list of alternating Langchain [HumanMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.human.HumanMessage.html) and [AIMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html) objects.\n", + "The \"messages\" value should be a list of alternating LangChain [HumanMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.human.HumanMessage.html) and [AIMessage](https://api.python.langchain.com/en/latest/messages/langchain_core.messages.ai.AIMessage.html) objects.\n", "\n", "For example:\n", "\n", @@ -304,9 +305,9 @@ "\n", "### Output Interface\n", "\n", - "All chains use the [Langchain Astream Events (v2) API](https://python.langchain.com/v0.1/docs/expression_language/streaming/#using-stream-events). This API supports various use cases (simple chains, RAG, Agents). This API emits asynchronous events that can be used to stream the chain's output.\n", + "All chains use the [LangChain Astream Events (v2) API](https://python.langchain.com/v0.1/docs/expression_language/streaming/#using-stream-events). This API supports various use cases (simple chains, RAG, Agents). This API emits asynchronous events that can be used to stream the chain's output.\n", "\n", - "Langchain chains (LCEL, Langraph) automatically implement the `astream_events` API. \n", + "LangChain chains (LCEL, Langraph) automatically implement the `astream_events` API. \n", "\n", "We provide examples of emitting `astream_events`-compatible events with custom Python code, allowing implementation with other SDKs (e.g., Vertex AI, LLamaIndex).\n", "\n", @@ -330,7 +331,13 @@ "metadata": {}, "outputs": [], "source": [ - "SUPPORTED_EVENTS = [\"on_tool_start\", \"on_tool_end\",\"on_retriever_start\", \"on_retriever_end\", \"on_chat_model_stream\"]" + "SUPPORTED_EVENTS = [\n", + " \"on_tool_start\",\n", + " \"on_tool_end\",\n", + " \"on_retriever_start\",\n", + " \"on_retriever_end\",\n", + " \"on_chat_model_stream\",\n", + "]" ] }, { @@ -347,14 +354,14 @@ "metadata": {}, "outputs": [], "source": [ - "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)\n" + "llm = ChatVertexAI(model_name=\"gemini-1.5-flash-001\", temperature=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Leveraging Langchain LCEL for Efficient Chain Composition\n", + "### Leveraging LangChain LCEL for Efficient Chain Composition\n", "\n", "LangChain Expression Language (LCEL) provides a declarative approach to composing chains seamlessly. Key benefits include:\n", "\n", @@ -435,22 +442,26 @@ " return \"It's 60 degrees and foggy.\"\n", " return \"It's 90 degrees and sunny.\"\n", "\n", + "\n", "tools = [search]\n", "\n", "# 2. Set up the language model\n", "llm = llm.bind_tools(tools)\n", "\n", + "\n", "# 3. 
Define workflow components\n", "def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n", " \"\"\"Determines whether to use tools or end the conversation.\"\"\"\n", - " last_message = state['messages'][-1]\n", + " last_message = state[\"messages\"][-1]\n", " return \"tools\" if last_message.tool_calls else END\n", "\n", + "\n", "async def call_model(state: MessagesState, config: RunnableConfig):\n", " \"\"\"Calls the language model and returns the response.\"\"\"\n", - " response = llm.invoke(state['messages'], config)\n", + " response = llm.invoke(state[\"messages\"], config)\n", " return {\"messages\": response}\n", "\n", + "\n", "# 4. Create the workflow graph\n", "workflow = StateGraph(MessagesState)\n", "workflow.add_node(\"agent\", call_model)\n", @@ -459,7 +470,7 @@ "\n", "# 5. Define graph edges\n", "workflow.add_conditional_edges(\"agent\", should_continue)\n", - "workflow.add_edge(\"tools\", 'agent')\n", + "workflow.add_edge(\"tools\", \"agent\")\n", "\n", "# 6. Compile the workflow\n", "chain = workflow.compile()" @@ -516,7 +527,7 @@ "\n", "vector_store = get_vector_store(embedding=embedding)\n", "retriever = vector_store.as_retriever(search_kwargs={\"k\": 20})\n", - "reranker = VertexAIRank(\n", + "compressor = VertexAIRank(\n", " project_id=PROJECT_ID,\n", " location_id=\"global\",\n", " ranking_config=\"default_ranking_config\",\n", @@ -529,9 +540,11 @@ "\n", "\n", "@custom_chain\n", - "def chain(input: Dict[str, Any], **kwargs) -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n", + "def chain(\n", + " input: Dict[str, Any], **kwargs\n", + ") -> Iterator[OnToolEndEvent | OnChatModelStreamEvent]:\n", " \"\"\"\n", - " Implements a RAG QA chain. Decorated with `custom_chain` to offer Langchain compatible astream_events\n", + " Implements a RAG QA chain. 
Decorated with `custom_chain` to offer LangChain compatible astream_events\n", " and invoke interface and OpenTelemetry tracing.\n", " \"\"\"\n", " # Generate optimized query\n", @@ -539,14 +552,14 @@ "\n", " # Retrieve and rank documents\n", " retrieved_docs = retriever.get_relevant_documents(query)\n", - " ranked_docs = reranker.compress_documents(documents=retrieved_docs, query=query)\n", + " ranked_docs = compressor.compress_documents(documents=retrieved_docs, query=query)\n", "\n", " # Yield tool results metadata\n", " yield OnToolEndEvent(data={\"input\": {\"query\": query}, \"output\": ranked_docs})\n", "\n", " # Stream LLM response\n", " for chunk in response_chain.stream(\n", - " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n", + " input={\"messages\": input[\"messages\"], \"relevant_documents\": ranked_docs}\n", " ):\n", " yield OnChatModelStreamEvent(data={\"chunk\": chunk})" ] @@ -703,7 +716,7 @@ "source": [ "scored_data[\"user\"] = scored_data[\"human_message\"].apply(lambda x: x[\"content\"])\n", "scored_data[\"reference\"] = scored_data[\"ai_message\"].apply(lambda x: x[\"content\"])\n", - "scored_data\n" + "scored_data" ] }, { @@ -771,7 +784,7 @@ "metadata": {}, "outputs": [], "source": [ - "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}\n" + "experiment_name = \"rapid-eval-langchain-eval\" # @param {type:\"string\"}" ] }, { @@ -791,10 +804,15 @@ "metadata": {}, "outputs": [], "source": [ - "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n", + "metrics = [\"fluency\", \"safety\", custom_faithfulness_metric]\n", "\n", "metrics = [custom_faithfulness_metric]\n", - "eval_task = EvalTask(dataset=scored_data, metrics=metrics, experiment=experiment_name, metric_column_mapping={\"user\":\"prompt\"} )\n", + "eval_task = EvalTask(\n", + " dataset=scored_data,\n", + " metrics=metrics,\n", + " experiment=experiment_name,\n", + " metric_column_mapping={\"user\": \"prompt\"},\n", + ")\n", "eval_result = eval_task.evaluate()" ] }, diff --git a/gemini/sample-apps/conversational-genai-app-template/poetry.lock b/gemini/sample-apps/conversational-genai-app-template/poetry.lock index 1376458260d..71b7de10f27 100644 --- a/gemini/sample-apps/conversational-genai-app-template/poetry.lock +++ b/gemini/sample-apps/conversational-genai-app-template/poetry.lock @@ -4069,7 +4069,7 @@ opentelemetry-semantic-conventions-ai = "0.4.1" [[package]] name = "opentelemetry-instrumentation-langchain" version = "0.30.1" -description = "OpenTelemetry Langchain instrumentation" +description = "OpenTelemetry LangChain instrumentation" optional = false python-versions = "<4,>=3.9" files = [ diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py index 9c61d3cddfa..90b8b2b0bae 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/side_bar.py @@ -28,7 +28,6 @@ class SideBar: - def __init__(self, st) -> None: self.st = st @@ -36,28 +35,36 @@ def init_side_bar(self): with self.st.sidebar: self.url_input_field = self.st.text_input( label="Service URL", - value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL) + value=os.environ.get("SERVICE_URL", DEFAULT_BASE_URL), ) self.should_authenticate_request = self.st.checkbox( label="Authenticate request", value=False, - help="If checked, any request to the server will contain an" + help="If 
checked, any request to the server will contain an" "Identity token to allow authentication. " "See the Cloud Run documentation to know more about authentication:" - "https://cloud.google.com/run/docs/authenticating/service-to-service" + "https://cloud.google.com/run/docs/authenticating/service-to-service", ) col1, col2, col3 = self.st.columns(3) with col1: if self.st.button("+ New chat"): - if len(self.st.session_state.user_chats[self.st.session_state['session_id']][ - "messages"]) > 0: + if ( + len( + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ]["messages"] + ) + > 0 + ): self.st.session_state.run_id = None - self.st.session_state['session_id'] = str(uuid.uuid4()) + self.st.session_state["session_id"] = str(uuid.uuid4()) self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) - self.st.session_state.user_chats[self.st.session_state['session_id']] = { + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ] = { "title": EMPTY_CHAT_NAME, "messages": [], } @@ -66,16 +73,20 @@ def init_side_bar(self): if self.st.button("Delete chat"): self.st.session_state.run_id = None self.st.session_state.session_db.clear() - self.st.session_state.user_chats.pop(self.st.session_state['session_id']) + self.st.session_state.user_chats.pop( + self.st.session_state["session_id"] + ) if len(self.st.session_state.user_chats) > 0: chat_id = list(self.st.session_state.user_chats.keys())[0] - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) else: - self.st.session_state['session_id'] = str(uuid.uuid4()) - self.st.session_state.user_chats[self.st.session_state['session_id']] = { + self.st.session_state["session_id"] = str(uuid.uuid4()) + self.st.session_state.user_chats[ + self.st.session_state["session_id"] + ] = { "title": EMPTY_CHAT_NAME, "messages": [], } @@ -84,47 +95,54 @@ def init_side_bar(self): save_chat(self.st) self.st.subheader("Recent") # Style the heading - + all_chats = list(reversed(self.st.session_state.user_chats.items())) for chat_id, chat in all_chats[:NUM_CHAT_IN_RECENT]: if self.st.button(chat["title"], key=chat_id): self.st.session_state.run_id = None - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) with self.st.expander("Other chats"): for chat_id, chat in all_chats[NUM_CHAT_IN_RECENT:]: if self.st.button(chat["title"], key=chat_id): self.st.session_state.run_id = None - self.st.session_state['session_id'] = chat_id + self.st.session_state["session_id"] = chat_id self.st.session_state.session_db.get_session( - session_id=self.st.session_state['session_id'], + session_id=self.st.session_state["session_id"], ) self.st.divider() self.st.header("Upload files from local") bucket_name = self.st.text_input( label="GCS Bucket for upload", - value=os.environ.get("BUCKET_NAME","gs://your-bucket-name") + value=os.environ.get("BUCKET_NAME", "gs://your-bucket-name"), ) - if 'checkbox_state' not in self.st.session_state: + if "checkbox_state" not in self.st.session_state: self.st.session_state.checkbox_state = True - + self.st.session_state.checkbox_state = self.st.checkbox( - 
"Upload to GCS first (suggested)", - value=True, - help=HELP_GCS_CHECKBOX + "Upload to GCS first (suggested)", value=True, help=HELP_GCS_CHECKBOX ) self.uploaded_files = self.st.file_uploader( - label="Send files from local", accept_multiple_files=True, + label="Send files from local", + accept_multiple_files=True, key=f"uploader_images_{self.st.session_state.uploader_key}", type=[ - "png", "jpg", "jpeg", "txt", "docx", - "pdf", "rtf", "csv", "tsv", "xlsx" - ] + "png", + "jpg", + "jpeg", + "txt", + "docx", + "pdf", + "rtf", + "csv", + "tsv", + "xlsx", + ], ) if self.uploaded_files and self.st.session_state.checkbox_state: upload_files_to_gcs(self.st, bucket_name, self.uploaded_files) @@ -136,7 +154,7 @@ def init_side_bar(self): "GCS uris (comma-separated)", value=self.st.session_state["gcs_uris_to_be_sent"], key=f"upload_text_area_{self.st.session_state.uploader_key}", - help=HELP_MESSAGE_MULTIMODALITY + help=HELP_MESSAGE_MULTIMODALITY, ) - - self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}") \ No newline at end of file + + self.st.caption(f"Note: {HELP_MESSAGE_MULTIMODALITY}") diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py index 11cad777de2..12947fac963 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/streamlit_app.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import json import uuid -from functools import partial from langchain_core.messages import HumanMessage from side_bar import SideBar +import streamlit as st from streamlit_feedback import streamlit_feedback from style.app_markdown import markdown_str from utils.local_chat_history import LocalChatMessageHistory @@ -25,8 +26,6 @@ from utils.multimodal_utils import format_content, get_parts_from_files from utils.stream_handler import Client, StreamHandler, get_chain_response -import streamlit as st - USER = "my_user" EMPTY_CHAT_NAME = "Empty chat" @@ -34,14 +33,14 @@ page_title="Playground", layout="wide", initial_sidebar_state="auto", - menu_items=None + menu_items=None, ) st.title("Playground") st.markdown(markdown_str, unsafe_allow_html=True) # First time Init of session variables if "user_chats" not in st.session_state: - st.session_state['session_id'] = str(uuid.uuid4()) + st.session_state["session_id"] = str(uuid.uuid4()) st.session_state.uploader_key = 0 st.session_state.run_id = None st.session_state.user_id = USER @@ -49,37 +48,41 @@ st.session_state["gcs_uris_to_be_sent"] = "" st.session_state.modified_prompt = None st.session_state.session_db = LocalChatMessageHistory( - session_id=st.session_state['session_id'], - user_id=st.session_state['user_id'], + session_id=st.session_state["session_id"], + user_id=st.session_state["user_id"], ) st.session_state.user_chats = st.session_state.session_db.get_all_conversations() - st.session_state.user_chats[st.session_state['session_id']] = { + st.session_state.user_chats[st.session_state["session_id"]] = { "title": EMPTY_CHAT_NAME, "messages": [], - } + } side_bar = SideBar(st=st) side_bar.init_side_bar() -client = Client(url=side_bar.url_input_field, authenticate_request=side_bar.should_authenticate_request) +client = Client( + url=side_bar.url_input_field, + authenticate_request=side_bar.should_authenticate_request, +) # Write all 
messages of current conversation -messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] +messages = st.session_state.user_chats[st.session_state["session_id"]]["messages"] for i, message in enumerate(messages): with st.chat_message(message["type"]): if message["type"] == "ai": - if message.get("tool_calls") and len(message.get("tool_calls")) > 0: - tool_expander = st.expander( - label="Tool Calls:", - expanded=False) + tool_expander = st.expander(label="Tool Calls:", expanded=False) with tool_expander: for index, tool_call in enumerate(message["tool_calls"]): # ruff: noqa: E501 - tool_call_output = message["additional_kwargs"]["tool_calls_outputs"][index] - msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \ - f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \ - f"\n\n**output:**\n " \ - f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```" + tool_call_output = message["additional_kwargs"][ + "tool_calls_outputs" + ][index] + msg = ( + f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" + f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" + f"\n\n**output:**\n " + f"```\n{json.dumps(tool_call_output['output'], indent=2)}\n```" + ) st.markdown(msg, unsafe_allow_html=True) st.markdown(format_content(message["content"]), unsafe_allow_html=True) @@ -88,29 +91,25 @@ refresh_button = f"{i}_refresh" delete_button = f"{i}_delete" content = message["content"] - - if isinstance(message["content"],list): + + if isinstance(message["content"], list): content = message["content"][-1]["text"] with col1: - st.button( - label="✎", - key=edit_button, - type="primary" - ) + st.button(label="✎", key=edit_button, type="primary") if message["type"] == "human": with col2: st.button( label="⟳", key=refresh_button, type="primary", - on_click=partial(MessageEditing.refresh_message, st, i, content) + on_click=partial(MessageEditing.refresh_message, st, i, content), ) with col3: st.button( label="X", key=delete_button, type="primary", - on_click=partial(MessageEditing.delete_message, st, i) + on_click=partial(MessageEditing.delete_message, st, i), ) if st.session_state[edit_button]: @@ -118,55 +117,52 @@ "Edit your message:", value=content, key=f"edit_box_{i}", - on_change=partial(MessageEditing.edit_message, st, i, message["type"])) + on_change=partial(MessageEditing.edit_message, st, i, message["type"]), + ) # Handle new (or modified) user prompt and response prompt = st.chat_input() if prompt is None: prompt = st.session_state.modified_prompt - + if prompt: st.session_state.modified_prompt = None parts = get_parts_from_files( upload_gcs_checkbox=st.session_state.checkbox_state, - uploaded_files=side_bar.uploaded_files, - gcs_uris=side_bar.gcs_uris + uploaded_files=side_bar.uploaded_files, + gcs_uris=side_bar.gcs_uris, ) st.session_state["gcs_uris_to_be_sent"] = "" - parts.append( - { - "type": "text", - "text": prompt - } + parts.append({"type": "text", "text": prompt}) + st.session_state.user_chats[st.session_state["session_id"]]["messages"].append( + HumanMessage(content=parts).model_dump() ) - st.session_state.user_chats[st.session_state['session_id']]["messages"].append( - HumanMessage(content=parts).model_dump()) human_message = st.chat_message("human") with human_message: existing_user_input = format_content(parts) user_input = st.markdown(existing_user_input, unsafe_allow_html=True) - + ai_message = st.chat_message("ai") with ai_message: status = st.status("Generating answerπŸ€–") stream_handler = StreamHandler(st=st) - 
get_chain_response( - st=st, - client=client, - stream_handler=stream_handler - ) + get_chain_response(st=st, client=client, stream_handler=stream_handler) status.update(label="Finished!", state="complete", expanded=False) - if st.session_state.user_chats[st.session_state['session_id']][ - "title"] == EMPTY_CHAT_NAME: + if ( + st.session_state.user_chats[st.session_state["session_id"]]["title"] + == EMPTY_CHAT_NAME + ): st.session_state.session_db.set_title( - st.session_state.user_chats[st.session_state['session_id']] + st.session_state.user_chats[st.session_state["session_id"]] ) - st.session_state.session_db.upsert_session(st.session_state.user_chats[st.session_state['session_id']]) + st.session_state.session_db.upsert_session( + st.session_state.user_chats[st.session_state["session_id"]] + ) if len(parts) > 1: st.session_state.uploader_key += 1 st.rerun() @@ -175,12 +171,10 @@ feedback = streamlit_feedback( feedback_type="faces", optional_text_label="[Optional] Please provide an explanation", - key=f"feedback-{st.session_state.run_id}" + key=f"feedback-{st.session_state.run_id}", ) if feedback is not None: client.log_feedback( feedback_dict=feedback, run_id=st.session_state.run_id, ) - - diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py index b91a810e9c3..b1332580450 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/style/app_markdown.py @@ -34,4 +34,4 @@ color: !important; } -""" \ No newline at end of file +""" diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py index 6e3967ccd39..ebbe0110ed7 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/local_chat_history.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from datetime import datetime +import os -import yaml from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import HumanMessage from utils.title_summary import chain_title +import yaml class LocalChatMessageHistory(BaseChatMessageHistory): def __init__( - self, - user_id: str, - session_id: str = "default", - base_dir: str = ".streamlit_chats" + self, + user_id: str, + session_id: str = "default", + base_dir: str = ".streamlit_chats", ) -> None: self.user_id = user_id self.session_id = session_id @@ -43,12 +43,13 @@ def get_session(self, session_id): def get_all_conversations(self): conversations = {} for filename in os.listdir(self.user_dir): - if filename.endswith('.yaml'): + if filename.endswith(".yaml"): file_path = os.path.join(self.user_dir, filename) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: conversation = yaml.safe_load(f) if not isinstance(conversation, list) or len(conversation) > 1: - raise ValueError(f"""Invalid format in {file_path}. + raise ValueError( + f"""Invalid format in {file_path}. YAML file can only contain one conversation with the following structure. 
- messages: @@ -61,17 +62,18 @@ def get_all_conversations(self): conversation["title"] = filename conversations[filename[:-5]] = conversation return dict( - sorted(conversations.items(), key=lambda x: x[1].get('update_time', ''))) + sorted(conversations.items(), key=lambda x: x[1].get("update_time", "")) + ) def upsert_session(self, session) -> None: - session['update_time'] = datetime.now().isoformat() - with open(self.session_file, 'w') as f: + session["update_time"] = datetime.now().isoformat() + with open(self.session_file, "w") as f: yaml.dump( [session], f, allow_unicode=True, default_flow_style=False, - encoding='utf-8' + encoding="utf-8", ) def set_title(self, session) -> None: @@ -100,4 +102,4 @@ def set_title(self, session) -> None: def clear(self) -> None: if os.path.exists(self.session_file): - os.remove(self.session_file) \ No newline at end of file + os.remove(self.session_file) diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py index e696714657d..7b829d67d26 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/message_editing.py @@ -12,27 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -class MessageEditing: +class MessageEditing: @staticmethod def edit_message(st, button_idx, message_type): button_id = f"edit_box_{button_idx}" if message_type == "human": - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']][ - "messages"] = messages[:button_idx] + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] st.session_state.modified_prompt = st.session_state[button_id] else: - st.session_state.user_chats[st.session_state['session_id']]["messages"][ - button_idx]["content"] = st.session_state[button_id] - + st.session_state.user_chats[st.session_state["session_id"]]["messages"][ + button_idx + ]["content"] = st.session_state[button_id] + @staticmethod def refresh_message(st, button_idx, content): - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx] + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] st.session_state.modified_prompt = content @staticmethod def delete_message(st, button_idx): - messages = st.session_state.user_chats[st.session_state['session_id']]["messages"] - st.session_state.user_chats[st.session_state['session_id']]["messages"] = messages[:button_idx] \ No newline at end of file + messages = st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] + st.session_state.user_chats[st.session_state["session_id"]][ + "messages" + ] = messages[:button_idx] diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py index 94d35231058..4c1c8a49738 100644 --- 
a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/multimodal_utils.py @@ -17,12 +17,16 @@ from google.cloud import storage -HELP_MESSAGE_MULTIMODALITY = "To ensure Gemini models can access the URIs you " \ - "provide, store all URIs in buckets within the same " \ - "GCP Project that Gemini uses." +HELP_MESSAGE_MULTIMODALITY = ( + "To ensure Gemini models can access the URIs you " + "provide, store all URIs in buckets within the same " + "GCP Project that Gemini uses." +) -HELP_GCS_CHECKBOX = "Enabling GCS upload will increase app performance by avoiding to" \ - " pass large byte strings to the model" +HELP_GCS_CHECKBOX = ( + "Enabling GCS upload will increase app performance by avoiding to" + " pass large byte strings to the model" +) def format_content(content): @@ -41,9 +45,12 @@ def format_content(content): if part["type"] == "image_url": image_url = part["image_url"]["url"] image_markdown = f'' - markdown = markdown + f""" + markdown = ( + markdown + + f""" - {image_markdown} """ + ) if part["type"] == "media": # Local other media if "data" in part: @@ -54,19 +61,27 @@ def format_content(content): if "image" in part["mime_type"]: image_url = gs_uri_to_https_url(part["file_uri"]) image_markdown = f'' - markdown = markdown + f""" + markdown = ( + markdown + + f""" - {image_markdown} """ + ) # GCS other media else: - image_url = gs_uri_to_https_url(part['file_uri']) - markdown = markdown + f"- Remote media: " \ - f"[{part['file_uri']}]({image_url})\n" - markdown = markdown + f""" + image_url = gs_uri_to_https_url(part["file_uri"]) + markdown = ( + markdown + f"- Remote media: " + f"[{part['file_uri']}]({image_url})\n" + ) + markdown = ( + markdown + + f""" {text}""" + ) return markdown - + def get_gcs_blob_mime_type(gcs_uri): """Fetches the MIME type (content type) of a Google Cloud Storage blob. @@ -103,17 +118,17 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris): content = { "type": "image_url", "image_url": { - "url": f"data:{uploaded_file.type};base64," \ - f"{base64.b64encode(im_bytes).decode('utf-8')}" + "url": f"data:{uploaded_file.type};base64," + f"{base64.b64encode(im_bytes).decode('utf-8')}" }, "file_name": uploaded_file.name, } else: content = { "type": "media", - "data": base64.b64encode(im_bytes).decode('utf-8'), + "data": base64.b64encode(im_bytes).decode("utf-8"), "file_name": uploaded_file.name, - "mime_type": uploaded_file.type + "mime_type": uploaded_file.type, } parts.append(content) @@ -122,11 +137,12 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris): content = { "type": "media", "file_uri": uri, - "mime_type": get_gcs_blob_mime_type(uri) + "mime_type": get_gcs_blob_mime_type(uri), } parts.append(content) return parts + def upload_bytes_to_gcs(bucket_name, blob_name, file_bytes, content_type=None): """Uploads a bytes object to Google Cloud Storage and returns the GCS URI. 
@@ -177,17 +193,17 @@ def gs_uri_to_https_url(gs_uri): def upload_files_to_gcs(st, bucket_name, files_to_upload): - bucket_name = bucket_name.replace("gs://", "") - uploaded_uris = [] - for file in files_to_upload: - if file: - file_bytes = file.read() - gcs_uri = upload_bytes_to_gcs( - bucket_name=bucket_name, - blob_name=file.name, - file_bytes=file_bytes, - content_type=file.type - ) - uploaded_uris.append(gcs_uri) - st.session_state.uploader_key += 1 - st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris) + bucket_name = bucket_name.replace("gs://", "") + uploaded_uris = [] + for file in files_to_upload: + if file: + file_bytes = file.read() + gcs_uri = upload_bytes_to_gcs( + bucket_name=bucket_name, + blob_name=file.name, + file_bytes=file_bytes, + content_type=file.type, + ) + uploaded_uris.append(gcs_uri) + st.session_state.uploader_key += 1 + st.session_state["gcs_uris_to_be_sent"] = ",".join(uploaded_uris) diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py index 9c560687152..bfe0e07c12d 100644 --- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py +++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/stream_handler.py @@ -17,12 +17,11 @@ from urllib.parse import urljoin import google.auth +from google.auth.exceptions import DefaultCredentialsError import google.auth.transport.requests import google.oauth2.id_token -import requests -from google.auth.exceptions import DefaultCredentialsError from langchain_core.messages import AIMessage - +import requests import streamlit as st @@ -78,25 +77,24 @@ def log_feedback(self, feedback_dict, run_id): "Content-Type": "application/json", } if self.authenticate_request: - headers["Authorization"] = f"Bearer {self.id_token}" + headers["Authorization"] = f"Bearer {self.id_token}" requests.post(url, data=json.dumps(feedback_dict), headers=headers) - def stream_events(self, data: Dict[str, Any]) -> Generator[ - Dict[str, Any], None, None]: + def stream_events( + self, data: Dict[str, Any] + ) -> Generator[Dict[str, Any], None, None]: """Stream events from the server, yielding parsed event data.""" - headers = { - "Content-Type": "application/json", - "Accept": "text/event-stream" - } + headers = {"Content-Type": "application/json", "Accept": "text/event-stream"} if self.authenticate_request: - headers["Authorization"] = f"Bearer {self.id_token}" + headers["Authorization"] = f"Bearer {self.id_token}" - with requests.post(self.url, json={"input": data}, headers=headers, - stream=True) as response: + with requests.post( + self.url, json={"input": data}, headers=headers, stream=True + ) as response: for line in response.iter_lines(): if line: try: - event = json.loads(line.decode('utf-8')) + event = json.loads(line.decode("utf-8")) # print(event) yield event except json.JSONDecodeError: @@ -125,9 +123,6 @@ def new_status(self, status_update: str) -> None: self.tool_expander.markdown(status_update) - - - class EventProcessor: """Processes events from the stream and updates the UI accordingly.""" @@ -144,14 +139,14 @@ def __init__(self, st, client, stream_handler): def process_events(self): """Process events from the stream, handling each event type appropriately.""" - messages = \ - self.st.session_state.user_chats[self.st.session_state['session_id']][ - "messages"] + messages = self.st.session_state.user_chats[ + 
self.st.session_state["session_id"] + ]["messages"] stream = self.client.stream_events( data={ "messages": messages, - "user_id": self.st.session_state['user_id'], - "session_id": self.st.session_state['session_id'], + "user_id": self.st.session_state["user_id"], + "session_id": self.st.session_state["session_id"], } ) @@ -173,11 +168,13 @@ def process_events(self): def handle_metadata(self, event: Dict[str, Any]) -> None: """Handle metadata events.""" - self.current_run_id = event['data'].get('run_id') + self.current_run_id = event["data"].get("run_id") def handle_tool_start(self, event: Dict[str, Any]) -> None: """Handle the start of a tool or retriever execution.""" - msg = f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`" + msg = ( + f"\n\nCalling tool: `{event['name']}` with args: `{event['data']['input']}`" + ) self.stream_handler.new_status(msg) def handle_tool_end(self, event: Dict[str, Any]) -> None: @@ -185,24 +182,26 @@ def handle_tool_end(self, event: Dict[str, Any]) -> None: data = event["data"] # Support tool events if isinstance(data["output"], dict): - tool_id = data["output"].get('tool_call_id', None) - tool_name = data["output"].get('name', 'Unknown Tool') + tool_id = data["output"].get("tool_call_id", None) + tool_name = data["output"].get("name", "Unknown Tool") # Support retriever events else: tool_id = event.get("id", "None") tool_name = event.get("name", event["event"]) - tool_input = data['input'] - tool_output = data['output'] + tool_input = data["input"] + tool_output = data["output"] tool_call = {"id": tool_id, "name": tool_name, "args": tool_input} self.tool_calls.append(tool_call) - tool_call_outputs = {"id":tool_id, "output":tool_output} + tool_call_outputs = {"id": tool_id, "output": tool_output} self.tool_calls_outputs.append(tool_call_outputs) - msg = f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" \ - f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" \ - f"\n\n**output:**\n " \ - f"```\n{json.dumps(tool_output, indent=2)}\n```" + msg = ( + f"\n\nEnding tool: `{tool_call['name']}` with\n **args:**\n" + f"```\n{json.dumps(tool_call['args'], indent=2)}\n```\n" + f"\n\n**output:**\n " + f"```\n{json.dumps(tool_output, indent=2)}\n```" + ) self.stream_handler.new_status(msg) def handle_chat_model_stream(self, event: Dict[str, Any]) -> None: @@ -211,7 +210,7 @@ def handle_chat_model_stream(self, event: Dict[str, Any]) -> None: content = data["chunk"]["content"] self.additional_kwargs = { **self.additional_kwargs, - **data["chunk"]["additional_kwargs"] + **data["chunk"]["additional_kwargs"], } if content and len(content.strip()) > 0: self.final_content += content @@ -225,12 +224,13 @@ def handle_end(self, event: Dict[str, Any]) -> None: content=self.final_content, tool_calls=self.tool_calls, id=self.current_run_id, - additional_kwargs=additional_kwargs + additional_kwargs=additional_kwargs, ).model_dump() - session = self.st.session_state['session_id'] + session = self.st.session_state["session_id"] self.st.session_state.user_chats[session]["messages"].append(final_message) self.st.session_state.run_id = self.current_run_id + def get_chain_response(st, client, stream_handler): """Process the chain response update the Streamlit UI. 
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
index a34efa51abd..443263c1a4f 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/title_summary.py
@@ -63,4 +63,4 @@
     MessagesPlaceholder(variable_name="messages"),
 ])
 
-chain_title = title_template | llm
\ No newline at end of file
+chain_title = title_template | llm
diff --git a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
index 28f40eee878..42cc057a8c5 100644
--- a/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
+++ b/gemini/sample-apps/conversational-genai-app-template/streamlit/utils/utils.py
@@ -20,8 +20,6 @@
 
 SAVED_CHAT_PATH = str(os.getcwd()) + "/.saved_chats"
 
-
-
 def preprocess_text(text):
     if text[0] == "\n":
         text = text[1:]
@@ -55,6 +53,6 @@ def save_chat(st):
             file,
             allow_unicode=True,
             default_flow_style=False,
-            encoding='utf-8',
+            encoding="utf-8",
         )
     st.toast(f"Chat saved to path: ↓ {Path(SAVED_CHAT_PATH) / filename}")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
index 25615c9b1fd..0f94b40b7a0 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
+++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_langgraph_dummy_agent.py
@@ -14,10 +14,9 @@
 
 import logging
 
-import pytest
-from langchain_core.messages import AIMessageChunk
-
 from app.patterns.langgraph_dummy_agent.chain import chain
+from langchain_core.messages import AIMessageChunk
+import pytest
 
 CHAIN_NAME = "Langgraph agent"
 
@@ -33,19 +32,21 @@ async def test_langgraph_chain_astream_events() -> None:
 
     events = [event async for event in chain.astream_events(input_dict, version="v2")]
 
-    assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \
-        f"got {len(events)}"
+    assert len(events) > 1, (
+        f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}"
+    )
 
     on_chain_stream_events = [
         event for event in events if event["event"] == "on_chat_model_stream"
     ]
 
-    assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \
-        f" for {CHAIN_NAME} chain"
+    assert on_chain_stream_events, (
+        f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain"
+    )
 
     for event in on_chain_stream_events:
-        assert AIMessageChunk.model_validate(event["data"]["chunk"]), (
-            f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
-        )
+        assert AIMessageChunk.model_validate(
+            event["data"]["chunk"]
+        ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}"
 
     logging.info(f"All assertions passed for {CHAIN_NAME} chain")
diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py
index 951631b177b..36ace08010b 100644
--- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py
+++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/integration/patterns/test_rag_qa.py @@ -14,10 +14,9 @@ import logging -import pytest -from langchain_core.messages import AIMessageChunk - from app.patterns.custom_rag_qa.chain import chain +from langchain_core.messages import AIMessageChunk +import pytest CHAIN_NAME = "Rag QA" @@ -33,19 +32,21 @@ async def test_rag_chain_astream_events() -> None: events = [event async for event in chain.astream_events(input_dict, version="v2")] - assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \ - f"got {len(events)}" + assert len(events) > 1, ( + f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}" + ) on_chain_stream_events = [ event for event in events if event["event"] == "on_chat_model_stream" - ] + ] - assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \ - f" for {CHAIN_NAME} chain" + assert on_chain_stream_events, ( + f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain" + ) for event in on_chain_stream_events: - assert AIMessageChunk.model_validate(event["data"]["chunk"]), ( - f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" - ) + assert AIMessageChunk.model_validate( + event["data"]["chunk"] + ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" logging.info(f"All assertions passed for {CHAIN_NAME} chain") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py index 9b9da19f584..73c27f93105 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_chain.py @@ -14,10 +14,9 @@ import logging -import pytest -from langchain_core.messages import AIMessageChunk - from app.chain import chain +from langchain_core.messages import AIMessageChunk +import pytest CHAIN_NAME = "Default" @@ -33,19 +32,21 @@ async def test_default_chain_astream_events() -> None: events = [event async for event in chain.astream_events(input_dict, version="v2")] - assert len(events) > 1, f"Expected multiple events for {CHAIN_NAME} chain, " \ - f"got {len(events)}" + assert len(events) > 1, ( + f"Expected multiple events for {CHAIN_NAME} chain, " f"got {len(events)}" + ) on_chain_stream_events = [ event for event in events if event["event"] == "on_chat_model_stream" - ] + ] - assert on_chain_stream_events, f"Expected at least one on_chat_model_stream event" \ - f" for {CHAIN_NAME} chain" + assert on_chain_stream_events, ( + f"Expected at least one on_chat_model_stream event" f" for {CHAIN_NAME} chain" + ) for event in on_chain_stream_events: - assert AIMessageChunk.model_validate(event["data"]["chunk"]), ( - f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" - ) + assert AIMessageChunk.model_validate( + event["data"]["chunk"] + ), f"Invalid AIMessageChunk for {CHAIN_NAME} chain: {event['data']['chunk']}" logging.info(f"All assertions passed for {CHAIN_NAME} chain") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py index c10e3c3e4d6..1c76307086f 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/integration/test_server_e2e.py @@ -18,8 +18,8 @@ import sys import threading import time -import uuid from typing import Any, Iterator +import uuid import pytest import requests @@ -35,11 +35,13 @@ HEADERS = {"Content-Type": "application/json"} + def log_output(pipe: Any, log_func: Any) -> None: """Log the output from the given pipe.""" - for line in iter(pipe.readline, ''): + for line in iter(pipe.readline, ""): log_func(line.strip()) + def start_server() -> subprocess.Popen[str]: """Start the FastAPI server using subprocess and log its output.""" command = [ @@ -58,18 +60,15 @@ def start_server() -> subprocess.Popen[str]: # Start threads to log stdout and stderr in real-time threading.Thread( - target=log_output, - args=(process.stdout, logger.info), - daemon=True + target=log_output, args=(process.stdout, logger.info), daemon=True ).start() threading.Thread( - target=log_output, - args=(process.stderr, logger.error), - daemon=True + target=log_output, args=(process.stderr, logger.error), daemon=True ).start() return process + def wait_for_server(timeout: int = 60, interval: int = 1) -> bool: """Wait for the server to be ready.""" start_time = time.time() @@ -85,6 +84,7 @@ def wait_for_server(timeout: int = 60, interval: int = 1) -> bool: logger.error(f"Server did not become ready within {timeout} seconds") return False + @pytest.fixture(scope="session") def server_fixture(request: Any) -> Iterator[subprocess.Popen[str]]: """Pytest fixture to start and stop the server for testing.""" @@ -103,6 +103,7 @@ def stop_server() -> None: request.addfinalizer(stop_server) yield server_process + def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: """Test the chat stream functionality.""" logger.info("Starting chat stream test") @@ -112,10 +113,10 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "What cooking recipes do you suggest?"} + {"role": "user", "content": "What cooking recipes do you suggest?"}, ], "user_id": "test-user", - "session_id": "test-session" + "session_id": "test-session", } } @@ -128,28 +129,35 @@ def test_chat_stream(server_fixture: subprocess.Popen[str]) -> None: logger.info(f"Received {len(events)} events") assert len(events) > 2, f"Expected more than 2 events, got {len(events)}." 
- assert events[0]["event"] == "metadata", f"First event should be 'metadata', " \ - f"got {events[0]['event']}" + assert events[0]["event"] == "metadata", ( + f"First event should be 'metadata', " f"got {events[0]['event']}" + ) assert "run_id" in events[0]["data"], "Missing 'run_id' in metadata" event_types = [event["event"] for event in events] assert "on_chat_model_stream" in event_types, "Missing 'on_chat_model_stream' event" - assert events[-1]["event"] == "end", f"Last event should be 'end', " \ - f"got {events[-1]['event']}" + assert events[-1]["event"] == "end", ( + f"Last event should be 'end', " f"got {events[-1]['event']}" + ) logger.info("Test completed successfully") + def test_chat_stream_error_handling(server_fixture: subprocess.Popen[str]) -> None: """Test the chat stream error handling.""" logger.info("Starting chat stream error handling test") data = {"input": [{"role": "invalid_role", "content": "Cause an error"}]} - response = requests.post(STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10) + response = requests.post( + STREAM_EVENTS_URL, headers=HEADERS, json=data, stream=True, timeout=10 + ) - assert response.status_code == 422, f"Expected status code 422, " \ - f"got {response.status_code}" + assert response.status_code == 422, ( + f"Expected status code 422, " f"got {response.status_code}" + ) logger.info("Error handling test completed successfully") + def test_collect_feedback(server_fixture: subprocess.Popen[str]) -> None: """ Test the feedback collection endpoint (/feedback) to ensure it properly diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py index 0e356953c45..73af9fd5f44 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/load_test/load_test.py @@ -26,17 +26,17 @@ class ChatStreamUser(HttpUser): def chat_stream(self) -> None: headers = {"Content-Type": "application/json"} if os.environ.get("_ID_TOKEN"): - headers['Authorization'] = f'Bearer {os.environ["_ID_TOKEN"]}' + headers["Authorization"] = f'Bearer {os.environ["_ID_TOKEN"]}' data = { "input": { "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "Who are you?"} + {"role": "user", "content": "Who are you?"}, ], "user_id": "test-user", - "session_id": "test-session" + "session_id": "test-session", } } @@ -48,7 +48,7 @@ def chat_stream(self) -> None: json=data, catch_response=True, name="/stream_events first event", - stream=True + stream=True, ) as response: if response.status_code == 200: events = [] @@ -73,9 +73,9 @@ def chat_stream(self) -> None: response_time=total_time * 1000, # Convert to milliseconds response_length=len(json.dumps(events)), response=response, - context={} + context={}, ) else: response.failure("Unexpected response structure") else: - response.failure(f"Unexpected status code: {response.status_code}") \ No newline at end of file + response.failure(f"Unexpected status code: {response.status_code}") diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py index f760cf500d6..51565628faf 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py +++ 
b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_server.py @@ -17,13 +17,12 @@ from typing import Any from unittest.mock import patch -import pytest +from app.server import app +from app.utils.input_types import InputChat from fastapi.testclient import TestClient from httpx import AsyncClient from langchain_core.messages import HumanMessage - -from app.server import app -from app.utils.input_types import InputChat +import pytest # Set up logging logging.basicConfig(level=logging.INFO) @@ -41,7 +40,7 @@ def sample_input_chat() -> InputChat: return InputChat( user_id="test-user", session_id="test-session", - messages=[HumanMessage(content="What is the meaning of life?")] + messages=[HumanMessage(content="What is the meaning of life?")], ) @@ -62,7 +61,7 @@ class AsyncIterator: def __init__(self, seq: list) -> None: self.iter = iter(seq) - def __aiter__(self) -> 'AsyncIterator': + def __aiter__(self) -> "AsyncIterator": return self async def __anext__(self) -> Any: @@ -77,7 +76,7 @@ def mock_chain() -> Any: """ Fixture to mock the chain object used in the application. """ - with patch('app.server.chain') as mock: + with patch("app.server.chain") as mock: yield mock @@ -94,8 +93,8 @@ async def test_stream_chat_events(mock_chain: Any) -> None: "messages": [ {"role": "user", "content": "Hello, AI!"}, {"role": "ai", "content": "Hello!"}, - {"role": "user", "content": "What cooking recipes do you suggest?"} - ] + {"role": "user", "content": "What cooking recipes do you suggest?"}, + ], } } @@ -107,8 +106,9 @@ async def test_stream_chat_events(mock_chain: Any) -> None: mock_chain.astream_events.return_value = AsyncIterator(mock_events) - with patch('uuid.uuid4', return_value=mock_uuid), \ - patch('app.server.Traceloop.set_association_properties'): + with patch("uuid.uuid4", return_value=mock_uuid), patch( + "app.server.Traceloop.set_association_properties" + ): async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.post("/stream_events", json=input_data) diff --git a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py index 73385724109..e584ee0b704 100644 --- a/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py +++ b/gemini/sample-apps/conversational-genai-app-template/tests/unit/test_utils/test_tracing_exporter.py @@ -14,41 +14,47 @@ from unittest.mock import Mock, patch -import pytest +from app.utils.tracing import CloudTraceLoggingSpanExporter from google.cloud import logging as gcp_logging from google.cloud import storage from opentelemetry.sdk.trace import ReadableSpan - -from app.utils.tracing import CloudTraceLoggingSpanExporter +import pytest @pytest.fixture def mock_logging_client() -> Mock: return Mock(spec=gcp_logging.Client) + @pytest.fixture def mock_storage_client() -> Mock: return Mock(spec=storage.Client) + @pytest.fixture -def exporter(mock_logging_client: Mock, mock_storage_client: Mock) -> CloudTraceLoggingSpanExporter: +def exporter( + mock_logging_client: Mock, mock_storage_client: Mock +) -> CloudTraceLoggingSpanExporter: return CloudTraceLoggingSpanExporter( project_id="test-project", logging_client=mock_logging_client, storage_client=mock_storage_client, - bucket_name="test-bucket" + bucket_name="test-bucket", ) + def test_init(exporter: CloudTraceLoggingSpanExporter) -> None: assert exporter.project_id == "test-project" assert 
exporter.bucket_name == "test-bucket" assert exporter.debug is False + def test_ensure_bucket_exists(exporter: CloudTraceLoggingSpanExporter) -> None: exporter.storage_client.bucket.return_value.exists.return_value = False exporter._ensure_bucket_exists() exporter.storage_client.create_bucket.assert_called_once_with("test-bucket") + def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None: span_id = "test-span-id" content = "test-content" @@ -56,26 +62,26 @@ def test_store_in_gcs(exporter: CloudTraceLoggingSpanExporter) -> None: assert uri == f"gs://test-bucket/spans/{span_id}.json" exporter.bucket.blob.assert_called_once_with(f"spans/{span_id}.json") -@patch('json.dumps') + +@patch("json.dumps") def test_process_large_attributes_small_payload( - mock_json_dumps: Mock, - exporter: CloudTraceLoggingSpanExporter + mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter ) -> None: - mock_json_dumps.return_value = 'a' * 100 # Small payload + mock_json_dumps.return_value = "a" * 100 # Small payload span_dict = {"attributes": {"key": "value"}} result = exporter._process_large_attributes(span_dict, "span-id") assert result == span_dict -@patch('json.dumps') + +@patch("json.dumps") def test_process_large_attributes_large_payload( - mock_json_dumps: Mock, - exporter: CloudTraceLoggingSpanExporter + mock_json_dumps: Mock, exporter: CloudTraceLoggingSpanExporter ) -> None: - mock_json_dumps.return_value = 'a' * (400 * 1024 + 1) # Large payload + mock_json_dumps.return_value = "a" * (400 * 1024 + 1) # Large payload span_dict = { "attributes": { "key1": "value1", - "traceloop.association.properties.key2": "value2" + "traceloop.association.properties.key2": "value2", } } result = exporter._process_large_attributes(span_dict, "span-id") @@ -84,16 +90,19 @@ def test_process_large_attributes_large_payload( assert "key1" not in result["attributes"] assert "traceloop.association.properties.key2" in result["attributes"] -@patch.object(CloudTraceLoggingSpanExporter, '_process_large_attributes') -def test_export(mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter) -> None: + +@patch.object(CloudTraceLoggingSpanExporter, "_process_large_attributes") +def test_export( + mock_process_large_attributes: Mock, exporter: CloudTraceLoggingSpanExporter +) -> None: mock_span = Mock(spec=ReadableSpan) mock_span.get_span_context.return_value.trace_id = 123 mock_span.get_span_context.return_value.span_id = 456 mock_span.to_json.return_value = '{"key": "value"}' - + mock_process_large_attributes.return_value = {"processed": "data"} - + exporter.export([mock_span]) - + mock_process_large_attributes.assert_called_once() exporter.logger.log_struct.assert_called_once()