Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
eliasecchig committed Sep 17, 2024
1 parent 5dccfb3 commit 6178be6
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 654 deletions.
605 changes: 0 additions & 605 deletions gemini/chat-completions/intro_chat_completions_api.ipynb

This file was deleted.

29 changes: 12 additions & 17 deletions gemini/sample-apps/conversational-genai-app-template/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,21 @@ select = [
line-length = 105

[tool.mypy]
# Ensure functions have type annotations
disallow_untyped_calls = true
disallow_untyped_defs = true

# Allow returning Any from functions with no return type specified
allow_untyped_calls = true

# Don't require type annotations for self or cls in methods
disallow_untyped_decorators = false

# Ignore missing imports
disallow_incomplete_defs = true
no_implicit_optional = true
check_untyped_defs = true
disallow_subclassing_any = true
warn_incomplete_stub = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unreachable = true
follow_imports = "silent"
ignore_missing_imports = true
explicit_package_bases = true
disable_error_code = ["misc", "no-untyped-call", "no-any-return"]

# Don't complain about missing return statements
warn_no_return = false

# Show error codes in error messages
show_error_codes = true

# Ensure all functions have return type annotations
warn_return_any = true


[tool.codespell]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import uuid
from typing import Any

from utils.multimodal_utils import (
HELP_GCS_CHECKBOX,
Expand All @@ -28,10 +29,10 @@


class SideBar:
def __init__(self, st) -> None:
def __init__(self, st: Any) -> None:
self.st = st

def init_side_bar(self):
def init_side_bar(self) -> None:
with self.st.sidebar:
self.url_input_field = self.st.text_input(
label="Service URL",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
from datetime import datetime
from typing import Dict

import yaml
from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -36,11 +37,11 @@ def __init__(

os.makedirs(self.user_dir, exist_ok=True)

def get_session(self, session_id):
def get_session(self, session_id: str) -> None:
self.session_id = session_id
self.session_file = os.path.join(self.user_dir, f"{session_id}.yaml")

def get_all_conversations(self):
def get_all_conversations(self) -> Dict[str, Dict]:
conversations = {}
for filename in os.listdir(self.user_dir):
if filename.endswith(".yaml"):
Expand All @@ -65,7 +66,7 @@ def get_all_conversations(self):
sorted(conversations.items(), key=lambda x: x[1].get("update_time", ""))
)

def upsert_session(self, session) -> None:
def upsert_session(self, session: Dict) -> None:
session["update_time"] = datetime.now().isoformat()
with open(self.session_file, "w") as f:
yaml.dump(
Expand All @@ -76,7 +77,7 @@ def upsert_session(self, session) -> None:
encoding="utf-8",
)

def set_title(self, session) -> None:
def set_title(self, session: Dict) -> None:
"""
Set the title for the given session.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any


class MessageEditing:
@staticmethod
def edit_message(st, button_idx, message_type):
def edit_message(st: Any, button_idx: int, message_type: str) -> None:
button_id = f"edit_box_{button_idx}"
if message_type == "human":
messages = st.session_state.user_chats[st.session_state["session_id"]][
Expand All @@ -31,7 +33,7 @@ def edit_message(st, button_idx, message_type):
]["content"] = st.session_state[button_id]

@staticmethod
def refresh_message(st, button_idx, content):
def refresh_message(st: Any, button_idx: int, content: str) -> None:
messages = st.session_state.user_chats[st.session_state["session_id"]][
"messages"
]
Expand All @@ -41,7 +43,7 @@ def refresh_message(st, button_idx, content):
st.session_state.modified_prompt = content

@staticmethod
def delete_message(st, button_idx):
def delete_message(st: Any, button_idx: int) -> None:
messages = st.session_state.user_chats[st.session_state["session_id"]][
"messages"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import base64
from typing import Any, Dict, List, Optional, Union
from urllib.parse import quote

from google.cloud import storage
Expand All @@ -28,7 +29,7 @@
)


def format_content(content):
def format_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if len(content) == 1 and content[0]["type"] == "text":
Expand Down Expand Up @@ -82,7 +83,7 @@ def format_content(content):
return markdown


def get_gcs_blob_mime_type(gcs_uri):
def get_gcs_blob_mime_type(gcs_uri: str) -> Optional[str]:
"""Fetches the MIME type (content type) of a Google Cloud Storage blob.
Args:
Expand All @@ -107,7 +108,11 @@ def get_gcs_blob_mime_type(gcs_uri):
return None # Indicate failure


def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris):
def get_parts_from_files(
upload_gcs_checkbox: bool,
uploaded_files: List[Any],
gcs_uris: str
) -> List[Dict[str, Any]]:
parts = []
# read from local directly
if not upload_gcs_checkbox:
Expand Down Expand Up @@ -142,7 +147,12 @@ def get_parts_from_files(upload_gcs_checkbox, uploaded_files, gcs_uris):
return parts


def upload_bytes_to_gcs(bucket_name, blob_name, file_bytes, content_type=None):
def upload_bytes_to_gcs(
bucket_name: str,
blob_name: str,
file_bytes: bytes,
content_type: Optional[str] = None
) -> str:
"""Uploads a bytes object to Google Cloud Storage and returns the GCS URI.
Args:
Expand All @@ -167,7 +177,7 @@ def upload_bytes_to_gcs(bucket_name, blob_name, file_bytes, content_type=None):
return gcs_uri


def gs_uri_to_https_url(gs_uri):
def gs_uri_to_https_url(gs_uri: str) -> str:
"""Converts a GS URI to an HTTPS URL without authentication.
Args:
Expand All @@ -191,7 +201,7 @@ def gs_uri_to_https_url(gs_uri):
return https_url


def upload_files_to_gcs(st, bucket_name, files_to_upload):
def upload_files_to_gcs(st: Any, bucket_name: str, files_to_upload: List[Any]) -> None:
bucket_name = bucket_name.replace("gs://", "")
uploaded_uris = []
for file in files_to_upload:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import Any, Dict, Generator
from typing import Any, Dict, Generator, List, Optional
from urllib.parse import urljoin

import google.auth
Expand All @@ -39,7 +39,7 @@ def __init__(self, url: str, authenticate_request: bool = False) -> None:
if self.authenticate_request:
self.id_token = self.get_id_token(self.url)

def get_id_token(self, url: str):
def get_id_token(self, url: str) -> str:
"""
Retrieves an ID token, attempting to use a service-to-service method first and
otherwise using user default credentials.
Expand All @@ -57,7 +57,7 @@ def get_id_token(self, url: str):
token = self.creds.id_token
return token

def log_feedback(self, feedback_dict, run_id):
def log_feedback(self, feedback_dict: Dict[str, Any], run_id: str) -> None:
score = feedback_dict["score"]
if score == "😞":
score = 0.0
Expand Down Expand Up @@ -105,7 +105,7 @@ def stream_events(
class StreamHandler:
"""Handles streaming updates to a Streamlit interface."""

def __init__(self, st, initial_text=""):
def __init__(self, st: Any, initial_text: str = "") -> None:
"""Initialize the StreamHandler with Streamlit context and initial text."""
self.st = st
self.tool_expander = st.expander("Tool Calls:", expanded=False)
Expand All @@ -127,18 +127,18 @@ def new_status(self, status_update: str) -> None:
class EventProcessor:
"""Processes events from the stream and updates the UI accordingly."""

def __init__(self, st, client, stream_handler):
def __init__(self, st: Any, client: Client, stream_handler: StreamHandler) -> None:
"""Initialize the EventProcessor with Streamlit context, client, and stream handler."""
self.st = st
self.client = client
self.stream_handler = stream_handler
self.final_content = ""
self.tool_calls = []
self.tool_calls_outputs = []
self.additional_kwargs = {}
self.current_run_id = None
self.tool_calls: List[Dict[str, Any]] = []
self.tool_calls_outputs: List[Dict[str, Any]] = []
self.additional_kwargs: Dict[str, Any] = {}
self.current_run_id: Optional[str] = None

def process_events(self):
def process_events(self) -> None:
"""Process events from the stream, handling each event type appropriately."""
messages = self.st.session_state.user_chats[
self.st.session_state["session_id"]
Expand All @@ -162,7 +162,7 @@ def process_events(self):
}

for event in stream:
event_type = event.get("event")
event_type = str(event.get("event"))
handler = event_handlers.get(event_type)
if handler:
handler(event)
Expand Down Expand Up @@ -232,15 +232,15 @@ def handle_end(self, event: Dict[str, Any]) -> None:
self.st.session_state.run_id = self.current_run_id


def get_chain_response(st, client, stream_handler):
def get_chain_response(st: Any, client: Client, stream_handler: StreamHandler) -> None:
"""Process the chain response update the Streamlit UI.
This function initiates the event processing for a chain of operations,
involving an AI model's response generation and potential tool calls.
It creates an EventProcessor instance and starts the event processing loop.
Args:
st (streamlit): The Streamlit app instance, used for accessing session state
st (Any): The Streamlit app instance, used for accessing session state
and updating the UI.
client (Client): An instance of the Client class used to stream events
from the server.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@

import os
from pathlib import Path
from typing import Any, Dict, List, Union

import yaml

SAVED_CHAT_PATH = str(os.getcwd()) + "/.saved_chats"


def preprocess_text(text):
def preprocess_text(text: str) -> str:
if text[0] == "\n":
text = text[1:]
if text[-1] == "\n":
text = text[:-1]
return text


def fix_messages(messages):
def fix_messages(
messages: List[Dict[str, Union[str, List[Dict[str, str]]]]]
) -> List[Dict[str, Union[str, List[Dict[str, str]]]]]:
for message in messages:
if isinstance(message["content"], list):
for part in message["content"]:
Expand All @@ -39,7 +42,7 @@ def fix_messages(messages):
return messages


def save_chat(st):
def save_chat(st: Any) -> None:
Path(SAVED_CHAT_PATH).mkdir(parents=True, exist_ok=True)
session_id = st.session_state["session_id"]
session = st.session_state.user_chats[session_id]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import time

from locust import HttpUser, between, task # type: ignore[import-not-found]
from locust import HttpUser, between, task


class ChatStreamUser(HttpUser):
Expand Down
Binary file added qemu_hadolint_20240917-164918_24475.core
Binary file not shown.

0 comments on commit 6178be6

Please sign in to comment.