Skip to content

Commit

Permalink
Resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla committed Oct 7, 2024
2 parents e494ef9 + a186751 commit 843311f
Show file tree
Hide file tree
Showing 41 changed files with 1,031 additions and 876 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ repos:
- id: mypy
# You can add additional plugins for mypy below
# such as types-python-dateutil
additional_dependencies: [pydantic>=2.8.2]
additional_dependencies: [pydantic>=2.8.2, types-pyyaml>=6.0.12]
exclude: (/test_|setup.py|/tests/|docs/)

# Sort imports alphabetically, and automatically separated into sections and by type.
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import ragbits

app = Typer()
app = Typer(no_args_is_help=True)


def main() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import chromadb

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/examples/prompt_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class LoremPrompt(Prompt[LoremPromptInput, LoremPromptOutput]):


if __name__ == "__main__":
lorem_prompt = LoremPrompt(LoremPromptInput(theme="business"))
lorem_prompt.add_assistant_message("Lorem Ipsum biznessum dolor copy machinum yearly reportum")
lorem_prompt = LoremPrompt(LoremPromptInput(theme="animals"))
lorem_prompt.add_few_shot("theme: business", "Lorem Ipsum biznessum dolor copy machinum yearly reportum")
print("CHAT:")
print(lorem_prompt.chat)
print()
Expand Down
3 changes: 3 additions & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ dependencies = [
]

[project.optional-dependencies]
chromadb = [
"chromadb~=0.4.24",
]
litellm = [
"litellm~=1.46.0",
]
Expand Down
19 changes: 0 additions & 19 deletions packages/ragbits-core/src/ragbits/core/cli.py

This file was deleted.

65 changes: 40 additions & 25 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=

system_prompt: Optional[str] = None
user_prompt: str
additional_messages: ChatFormat = []

# Additional messages to be added to the conversation after the system prompt
few_shots: ChatFormat = []

# function that parses the response from the LLM to specific output type
# if not provided, the class tries to set it automatically based on the output type
Expand Down Expand Up @@ -111,10 +113,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if self.input_type and input_data is None:
raise ValueError("Input data must be provided")

self.system_message = (
self.rendered_system_prompt = (
self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
)
self.user_message = self._render_template(self.user_prompt_template, input_data)
self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)

# Additional few shot examples that can be added dynamically using methods
# (in opposite to the static `few_shots` attribute which is defined in the class)
self._instace_few_shots: ChatFormat = []
super().__init__()

@property
Expand All @@ -125,35 +131,31 @@ def chat(self) -> ChatFormat:
Returns:
ChatFormat: A list of dictionaries, each containing the role and content of a message.
"""
return [
*([{"role": "system", "content": self.system_message}] if self.system_message is not None else []),
{"role": "user", "content": self.user_message},
] + self.additional_messages

def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the user to the conversation.
Args:
message (str): The message to add.
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
self.additional_messages.append({"role": "user", "content": message})
return self

def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]":
chat = [
*(
[{"role": "system", "content": self.rendered_system_prompt}]
if self.rendered_system_prompt is not None
else []
),
*self.few_shots,
*self._instace_few_shots,
{"role": "user", "content": self.rendered_user_prompt},
]
return chat

def add_few_shot(self, user_message: str, assistant_message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the assistant to the conversation.
Add a few-shot example to the conversation.
Args:
message (str): The message to add.
user_message (str): The message from the user.
assistant_message (str): The message from the assistant.
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
self.additional_messages.append({"role": "assistant", "content": message})
self._instace_few_shots.append({"role": "user", "content": user_message})
self._instace_few_shots.append({"role": "assistant", "content": assistant_message})
return self

def output_schema(self) -> Optional[Dict | Type[BaseModel]]:
Expand Down Expand Up @@ -190,3 +192,16 @@ def parse_response(self, response: str) -> OutputT:
ResponseParsingError: If the response cannot be parsed.
"""
return self.response_parser(response)

@classmethod
def to_promptfoo(cls, config: dict[str, Any]) -> ChatFormat:
"""
Generate a prompt in the promptfoo format from a promptfoo test configuration.
Args:
config: The promptfoo test configuration.
Returns:
ChatFormat: The prompt in the format used by promptfoo.
"""
return cls(cls.input_type.model_validate(config["vars"])).chat # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ..utils import get_cls_from_config
from .base import VectorStore
from .chromadb_store import ChromaDBStore
# from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorStore", "ChromaDBStore"]
__all__ = ["InMemoryVectorStore", "VectorStore"] # , "ChromaDBStore"]

module = sys.modules[__name__]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from copy import deepcopy
from hashlib import sha256
from typing import List, Literal, Optional, Union

Expand All @@ -11,8 +10,8 @@
HAS_CHROMADB = False

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.vector_store.base import VectorStore
from ragbits.document_search.vector_store.in_memory import VectorDBEntry
from ragbits.core.vector_store.base import VectorStore
from ragbits.core.vector_store.in_memory import VectorDBEntry


class ChromaDBStore(VectorStore):
Expand Down Expand Up @@ -79,48 +78,16 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]:

return None

def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]:
def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dict]:
doc_id = sha256(entry.key.encode("utf-8")).hexdigest()
embedding = entry.vector
text = entry.metadata["content"]

metadata = deepcopy(entry.metadata)
metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"])
metadata["key"] = entry.key
metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()}
metadata = {
"__key": entry.key,
"__metadata": json.dumps(entry.metadata, default=str),
}

return doc_id, embedding, text, metadata

def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]:
"""
Processes the metadata dictionary by parsing JSON strings if applicable.
Args:
metadata: A dictionary containing metadata where values may be JSON strings.
Returns:
A dictionary with the same keys as the input, where JSON strings are parsed
into their respective Python data types.
"""
return {key: json.loads(val) if self._is_json(val) else val for key, val in metadata.items()}

def _is_json(self, myjson: str) -> bool:
"""
Check if the provided string is a valid JSON.
Args:
myjson: The string to be checked.
Returns:
True if the string is a valid JSON, False otherwise.
"""
try:
if isinstance(myjson, str):
json.loads(myjson)
return True
return False
except ValueError:
return False
return doc_id, embedding, metadata

@property
def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]:
Expand All @@ -139,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None:
Args:
entries: The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))
ids, embeddings, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Expand All @@ -157,43 +122,20 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
db_entry = VectorDBEntry(
key=meta[0].get("key"),
key=meta[0]["__key"],
vector=vector,
metadata=self._process_metadata(meta[0]),
metadata=json.loads(meta[0]["__metadata"]),
)

db_entries.append(db_entry)

return db_entries

async def find_similar(self, text: str) -> Optional[str]:
"""
Finds the most similar text in the chroma collection or returns None if the most similar text
has distance bigger than `self.max_distance`.
Args:
text: The text to find similar to.
Returns:
The most similar text or None if no similar text is found.
"""

collection = self._get_chroma_collection()

if isinstance(self._embedding_function, Embeddings):
embedding = await self._embedding_function.embed_text([text])
retrieved = collection.query(query_embeddings=embedding, n_results=1)
else:
retrieved = collection.query(query_texts=[text], n_results=1)

return self._return_best_match(retrieved)

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from ragbits.document_search.vector_store.base import VectorDBEntry, VectorStore
from ragbits.core.vector_store.base import VectorDBEntry, VectorStore


class InMemoryVectorStore(VectorStore):
Expand Down
Loading

0 comments on commit 843311f

Please sign in to comment.