Skip to content

Commit

Permalink
Aligning to micpst's comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Oct 14, 2024
1 parent c98bda1 commit 57ef338
Show file tree
Hide file tree
Showing 20 changed files with 114 additions and 105 deletions.
3 changes: 2 additions & 1 deletion packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


def main() -> None:
"""Main entry point for the CLI.
"""
Main entry point for the CLI.
This function registers all the CLI modules in the ragbits packages:
- iterates over every package in the ragbits.* namespace
Expand Down
12 changes: 9 additions & 3 deletions packages/ragbits-core/examples/llm_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,27 @@


class LoremPromptInput(BaseModel):
"""Input format for the LoremPrompt."""
"""
Input format for the LoremPrompt.
"""

theme: str
pun_allowed: bool = False


class LoremPromptOutput(BaseModel):
"""Output format for the LoremPrompt."""
"""
Output format for the LoremPrompt.
"""

joke: str
joke_category: str


class JokePrompt(Prompt[LoremPromptInput, LoremPromptOutput]):
"""A prompt that generates jokes."""
"""
A prompt that generates jokes.
"""

system_prompt = """
You are a joke generator. The jokes you generate should be funny and not offensive.
Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


def register(app: typer.Typer) -> None:
"""Register the CLI commands for the package.
"""
Register the CLI commands for the package.
Args:
app: The Typer object to register the commands with.
Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ class Embeddings(ABC):

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
"""Creates embeddings for the given strings.
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
Expand Down
12 changes: 8 additions & 4 deletions packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

@dataclass
class LiteLLMOptions(LLMOptions):
"""Dataclass that represents all available LLM call options for the LiteLLM client.
"""
Dataclass that represents all available LLM call options for the LiteLLM client.
Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input).
"""

Expand All @@ -35,7 +36,8 @@ class LiteLLMOptions(LLMOptions):


class LiteLLMClient(LLMClient[LiteLLMOptions]):
"""Client for the LiteLLM that supports calls to 100+ LLMs APIs, including OpenAI, Anthropic, VertexAI,
"""
Client for the LiteLLM that supports calls to 100+ LLMs APIs, including OpenAI, Anthropic, VertexAI,
Hugging Face and others.
"""

Expand All @@ -50,7 +52,8 @@ def __init__(
api_version: str | None = None,
use_structured_output: bool = False,
) -> None:
"""Constructs a new LiteLLMClient instance.
"""
Constructs a new LiteLLMClient instance.
Args:
model_name: Name of the model to use.
Expand Down Expand Up @@ -78,7 +81,8 @@ async def call(
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
"""Calls the appropriate LLM endpoint with the given prompt and options.
"""
Calls the appropriate LLM endpoint with the given prompt and options.
Args:
conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
Expand Down
6 changes: 4 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/clients/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(
*,
hf_api_key: str | None = None,
) -> None:
"""Constructs a new local LLMClient instance.
"""
Constructs a new local LLMClient instance.
Args:
model_name: Name of the model to use.
Expand All @@ -74,7 +75,8 @@ async def call(
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
"""Makes a call to the local LLM with the provided prompt and options.
"""
Makes a call to the local LLM with the provided prompt and options.
Args:
conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
Expand Down
6 changes: 4 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(
api_version: str | None = None,
use_structured_output: bool = False,
) -> None:
"""Constructs a new LiteLLM instance.
"""
Constructs a new LiteLLM instance.
Args:
model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used.\
Expand Down Expand Up @@ -71,7 +72,8 @@ def client(self) -> LiteLLMClient:
)

def count_tokens(self, prompt: BasePrompt) -> int:
"""Counts tokens in the prompt.
"""
Counts tokens in the prompt.
Args:
prompt: Formatted prompt template with conversation and response parsing configuration.
Expand Down
9 changes: 6 additions & 3 deletions packages/ragbits-core/src/ragbits/core/llms/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(
*,
api_key: str | None = None,
) -> None:
"""Constructs a new local LLM instance.
"""
Constructs a new local LLM instance.
Args:
model_name: Name of the model to use. This should be a model from the CausalLM class.
Expand All @@ -46,15 +47,17 @@ def __init__(

@cached_property
def client(self) -> LocalLLMClient:
"""Client for the LLM.
"""
Client for the LLM.
Returns:
The client used to interact with the LLM.
"""
return LocalLLMClient(model_name=self.model_name, hf_api_key=self.api_key)

def count_tokens(self, prompt: BasePrompt) -> int:
"""Counts tokens in the messages.
"""
Counts tokens in the messages.
Args:
prompt: Messages to count tokens for.
Expand Down
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/llms/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""A sentinel singleton class used to distinguish omitted keyword arguments
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


class PromptDiscovery:
"""Discovers Prompt objects within Python modules.
"""
Discovers Prompt objects within Python modules.
Args:
file_pattern (str): The file pattern to search for Prompt objects. Defaults to "**/prompt_*.py"
Expand All @@ -22,7 +23,8 @@ def __init__(self, file_pattern: str = core_config.prompt_path_pattern, root_pat

@staticmethod
def is_prompt_subclass(obj: Any) -> bool: # noqa: ANN401
"""Checks if an object is a class that is a subclass of Prompt (but not Prompt itself).
"""
Checks if an object is a class that is a subclass of Prompt (but not Prompt itself).
Args:
obj (any): The object to check.
Expand All @@ -35,7 +37,8 @@ def is_prompt_subclass(obj: Any) -> bool: # noqa: ANN401
return inspect.isclass(obj) and not get_origin(obj) and issubclass(obj, Prompt) and obj != Prompt

def discover(self) -> set[type[Prompt]]:
"""Discovers Prompt objects within the specified file paths.
"""
Discovers Prompt objects within the specified file paths.
Returns:
set[Prompt]: The discovered Prompt objects.
Expand Down
4 changes: 1 addition & 3 deletions packages/ragbits-core/src/ragbits/core/prompt/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def bool_parser(value: str) -> bool:
raise ResponseParsingError(f"Could not parse '{value}' as a boolean")


def build_pydantic_parser(
model: type[PydanticModelT],
) -> Callable[[str], PydanticModelT]:
def build_pydantic_parser(model: type[PydanticModelT]) -> Callable[[str], PydanticModelT]:
"""
Builds a parser for a specific Pydantic model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def __init__(self) -> None:
self._storage: dict[str, VectorDBEntry] = {}

async def store(self, entries: list[VectorDBEntry]) -> None:
"""Store entries in the vector store.
"""
Store entries in the vector store.
Args:
entries: The entries to store.
Expand All @@ -21,7 +22,8 @@ async def store(self, entries: list[VectorDBEntry]) -> None:
self._storage[entry.key] = entry

async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]:
"""Retrieve entries from the vector store.
"""
Retrieve entries from the vector store.
Args:
vector: The vector to search for.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


class DocumentType(str, Enum):
"""Types of documents that can be stored."""
"""
Types of documents that can be stored.
"""

MD = "md"
TXT = "txt"
Expand Down Expand Up @@ -50,7 +52,8 @@ def id(self) -> str:
return self.source.get_id()

async def fetch(self) -> "Document":
"""This method fetches the document from source (potentially remote) and creates an object to interface with it.
"""
This method fetches the document from source (potentially remote) and creates an object to interface with it.
Based on the document type, it will return a different object.
Returns:
Expand All @@ -61,7 +64,8 @@ async def fetch(self) -> "Document":

@classmethod
def create_text_document_from_literal(cls, content: str) -> "DocumentMeta":
"""Create a text document from a literal content.
"""
Create a text document from a literal content.
Args:
content: The content of the document.
Expand All @@ -79,7 +83,8 @@ def create_text_document_from_literal(cls, content: str) -> "DocumentMeta":

@classmethod
def from_local_path(cls, local_path: Path) -> "DocumentMeta":
"""Create a document metadata from a local path.
"""
Create a document metadata from a local path.
Args:
local_path: The local path to the document.
Expand All @@ -94,14 +99,17 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":


class Document(BaseModel):
"""An object representing a document which is downloaded and stored locally."""
"""
An object representing a document which is downloaded and stored locally.
"""

local_path: Path
metadata: DocumentMeta

@classmethod
def from_document_meta(cls, document_meta: DocumentMeta, local_path: Path) -> "Document":
"""Create a document from a document metadata.
"""
Create a document from a document metadata.
Based on the document type, it will return a different object.
Args:
Expand All @@ -117,11 +125,14 @@ def from_document_meta(cls, document_meta: DocumentMeta, local_path: Path) -> "D


class TextDocument(Document):
"""An object representing a text document."""
"""
An object representing a text document.
"""

@property
def content(self) -> str:
"""Get the content of the document.
"""
Get the content of the document.
Returns:
The content of the document.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class Element(BaseModel, ABC):

@abstractmethod
def get_key(self) -> str:
"""Get the key of the element which will be used to generate the vector.
"""
Get the key of the element which will be used to generate the vector.
Returns:
The key.
Expand All @@ -36,7 +37,8 @@ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: # pylint: disable=u

@classmethod
def from_vector_db_entry(cls, db_entry: VectorDBEntry) -> "Element":
"""Create an element from a vector database entry.
"""
Create an element from a vector database entry.
Args:
db_entry: The vector database entry.
Expand All @@ -51,7 +53,8 @@ def from_vector_db_entry(cls, db_entry: VectorDBEntry) -> "Element":
return element_cls(**meta)

def to_vector_db_entry(self, vector: list[float]) -> VectorDBEntry:
"""Create a vector database entry from the element.
"""
Create a vector database entry from the element.
Args:
vector: The vector.
Expand All @@ -67,13 +70,16 @@ def to_vector_db_entry(self, vector: list[float]) -> VectorDBEntry:


class TextElement(Element):
"""An object representing a text element in a document."""
"""
An object representing a text element in a document.
"""

element_type: str = "text"
content: str

def get_key(self) -> str:
"""Get the key of the element which will be used to generate the vector.
"""
Get the key of the element which will be used to generate the vector.
Returns:
The key.
Expand Down
Loading

0 comments on commit 57ef338

Please sign in to comment.