Skip to content

Commit

Permalink
Pylint - docstrings and cosmetics
Browse files (browse the repository at this point in the history)
Signed-off-by: Ygal Blum <[email protected]>
  • Loading branch information
ygalblum committed Mar 24, 2024
1 parent afc82b2 commit ccf2a41
Show file tree
Hide file tree
Showing 23 changed files with 164 additions and 78 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ disable=
; duplicate-string-formatting-argument, # TMP: will be fixed in close future
consider-using-f-string, # sorry, not gonna happen, still have to support py2
; use-dict-literal
cyclic-import,

[FORMAT]
# Maximum number of characters on a single line.
Expand Down
Empty file.
Empty file.
1 change: 1 addition & 0 deletions knowledge_base_gpt/apps/ingest/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Content ingestion application """
from knowledge_base_gpt.libs.injector.di import global_injector
from knowledge_base_gpt.apps.ingest.ingest import Ingestor

Expand Down
14 changes: 8 additions & 6 deletions knowledge_base_gpt/apps/ingest/ingest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
""" Content ingest """
from typing import List

from injector import inject, singleton
Expand All @@ -11,7 +11,8 @@


@singleton
class Ingestor():
class Ingestor(): # pylint:disable=R0903
""" Content ingest """

@inject
def __init__(self, settings: Settings, loader: Loader, vector_store: VectorStore) -> None:
Expand All @@ -20,11 +21,11 @@ def __init__(self, settings: Settings, loader: Loader, vector_store: VectorStore
self._chunk_overlap = settings.text_splitter.chunk_overlap
self._vector_store = vector_store

def _process_documents(self, ignored_files: List[str] = []) -> List[Document]:
def _process_documents(self, ignored_files: List[str]) -> List[Document]:
"""
Load documents and split in chunks
"""
print(f"Loading documents")
print("Loading documents")
documents = self._loader.load_documents(ignored_files)
if not documents:
return []
Expand All @@ -36,12 +37,13 @@ def _process_documents(self, ignored_files: List[str] = []) -> List[Document]:
return documents

def run(self):
""" Ingest the documents into the vector store based on the settings """
collection = self._vector_store.db.get()
documents = self._process_documents(list(set(metadata['source'] for metadata in collection['metadatas'])))
if len(documents) == 0:
print("No new documents to load")
else:
print(f"Creating embeddings. May take some minutes...")
print("Creating embeddings. May take some minutes...")
self._vector_store.db.add_documents(documents)
self._vector_store.db.persist()
print(f"Ingestion complete! You can now run privateGPT.py to query your documents")
print("Ingestion complete! You can now run privateGPT.py to query your documents")
Empty file.
1 change: 1 addition & 0 deletions knowledge_base_gpt/apps/slackbot/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Slackbot application """
from knowledge_base_gpt.libs.injector.di import global_injector
from knowledge_base_gpt.apps.slackbot.slack_bot import KnowledgeBaseSlackBot

Expand Down
26 changes: 18 additions & 8 deletions knowledge_base_gpt/apps/slackbot/slack_bot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Slackbot application backend """
from injector import inject, singleton
from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
Expand All @@ -9,11 +10,12 @@


class KnowledgeBaseSlackBotException(Exception):
pass
""" Wrapper for SlackBot specific exception """


@singleton
class KnowledgeBaseSlackBot():
class KnowledgeBaseSlackBot(): # pylint:disable=R0903
""" Slackbot application backend """

@inject
def __init__(self, settings: Settings, private_chat: PrivateChat, history: HistoryRedis) -> None:
Expand All @@ -28,18 +30,19 @@ def __init__(self, settings: Settings, private_chat: PrivateChat, history: Histo
self._handler.app.command('/conversation_forward')(self._forward_question)

def run(self):
""" Start the Slackbot backend application """
self._handler.start()

def _get_forward_question_channel_id(self):
if self._forward_question_channel_name is None:
raise KnowledgeBaseSlackBotException(f"Slackbot forward channel name was not set")
raise KnowledgeBaseSlackBotException("Slackbot forward channel name was not set")
try:
for result in self._handler.app.client.conversations_list():
for channel in result["channels"]:
if channel["name"] == self._forward_question_channel_name:
return channel["id"]
except SlackApiError as e:
raise KnowledgeBaseSlackBotException(e)
raise KnowledgeBaseSlackBotException(e) from e
raise KnowledgeBaseSlackBotException(f"The channel {self._forward_question_channel_name} does not exits")

def _got_message(self, message, say):
Expand All @@ -48,7 +51,11 @@ def _got_message(self, message, say):
user=message['user'],
text="On it. Be back with your answer soon"
)
answer = self._private_chat.answer_query(self._history.get_messages(message['user']), message['text'], chat_identifier=message['user'])
answer = self._private_chat.answer_query(
self._history.get_messages(message['user']),
message['text'],
chat_identifier=message['user']
)
self._history.add_to_history(message['user'], answer)
say(answer['answer'])

Expand All @@ -63,7 +70,7 @@ def _is_direct_message_channel(self, command):
return False


def _reset_conversation(self, ack, say, command):
def _reset_conversation(self, ack, say, command): # pylint:disable=unused-argument
ack()
if not self._is_direct_message_channel(command):
return
Expand All @@ -85,7 +92,7 @@ def _messages_to_text(messages):
text += '\n'
return text

def _forward_question(self, ack, say, command):
def _forward_question(self, ack, say, command): # pylint:disable=unused-argument
ack()
if not self._is_direct_message_channel(command):
return
Expand All @@ -94,7 +101,10 @@ def _forward_question(self, ack, say, command):
if len(messages) == 0:
msg = 'There is no active conversation'
else:
self._handler.app.client.chat_postMessage(channel=self._forward_question_channel_id, text=self._messages_to_text(messages))
self._handler.app.client.chat_postMessage(
channel=self._forward_question_channel_id,
text=self._messages_to_text(messages)
)
msg = f'The conversation was forwarded to {self._forward_question_channel_name}'

self._handler.app.client.chat_postEphemeral(
Expand Down
1 change: 1 addition & 0 deletions knowledge_base_gpt/libs/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Hold application constants """
from pathlib import Path


Expand Down
11 changes: 7 additions & 4 deletions knowledge_base_gpt/libs/embedding/embedding.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
""" Create and abstract the embedding """
from injector import inject, singleton
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.embeddings import Embeddings as LangChainEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings, OllamaEmbeddings

from knowledge_base_gpt.libs.settings.settings import Settings


@singleton
class Embedding():
class Embedding(): # pylint:disable=R0903
""" Create and abstract the embedding """
@inject
def __init__(self, settings: Settings) -> None:
mode = settings.embedding.mode
Expand All @@ -31,5 +33,6 @@ def __init__(self, settings: Settings) -> None:
pass

@property
def embeddings(self):
def embeddings(self) -> LangChainEmbeddings:
""" Return the embedding implementation """
return self._embeddings
25 changes: 22 additions & 3 deletions knowledge_base_gpt/libs/gpt/ollama_info.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""
Keep track of metrics provided by the Ollama API
"""

from contextlib import contextmanager
from contextvars import ContextVar
import threading
from typing import Any, Dict, List, Optional, Generator
from uuid import UUID

from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.tracers.context import register_configure_hook


class OllamaMetrics():
""" Metrics of a single Ollama request """
prompt_eval_count: int = 0
eval_count: int = 0
load_duration: int = 0
Expand All @@ -28,6 +35,7 @@ def __repr__(self) -> str:
)

def to_json(self) -> dict:
""" Return a JSON representation of the tracked info """
return {
"prompt_eval_count": self.prompt_eval_count,
"eval_count": self.eval_count,
Expand Down Expand Up @@ -61,13 +69,11 @@ def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: LLMResult, **_kwargs: Any) -> None:
"""Collect token usage."""
if len(response.generations) == 0 or len(response.generations[0]) == 0:
return
Expand All @@ -88,6 +94,19 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
with self._lock:
self.metrics.append(metrics)

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running. """

def __copy__(self) -> "OllamaCallbackHandler":
"""Return a copy of the callback handler."""
return self
Expand Down
15 changes: 11 additions & 4 deletions knowledge_base_gpt/libs/gpt/private_chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional
"""
Module for handling the chat chain
"""
from typing import Optional, Dict, Any

from injector import inject, singleton
from langchain.chains import ConversationalRetrievalChain
Expand All @@ -12,8 +15,8 @@


@singleton
class PrivateChat():

class PrivateChat(): # pylint:disable=R0903
""" Handle the Chat chain """
@inject
def __init__(self, settings: Settings, chat_log_exporter: ChatLogExporter, vector_store: VectorStore):
llm_mode = settings.llm.mode
Expand Down Expand Up @@ -46,7 +49,11 @@ def __init__(self, settings: Settings, chat_log_exporter: ChatLogExporter, vecto
return_generated_question=True
)

def answer_query(self, history, query, chat_identifier: Optional[str]=None):
def answer_query(self, history, query, chat_identifier: Optional[str]=None) -> Dict[str, Any]:
"""
Answer the query based on the history
Use the chat identifier for logging the chat
"""
with self._get_callback() as cb:
answer = self._chain.invoke({"question": query, "chat_history": history})
self._chat_log_exporter.save_chat_log(self._chat_fragment_cls(answer, cb, chat_identifier=chat_identifier))
Expand Down
16 changes: 9 additions & 7 deletions knowledge_base_gpt/libs/history/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
""" Base class for history keepers """
from abc import ABC, abstractmethod
from typing import List
from typing import List, Dict, Any

from langchain_core.messages import BaseMessage


class HistoryBase(ABC):
""" Base class for history keepers """
@abstractmethod
def get_messages(self, session_id) -> List[BaseMessage]:
pass
def get_messages(self, session_id: str) -> List[BaseMessage]:
""" Get all messages of the session """

@abstractmethod
def add_to_history(self, session_id, answer):
pass
def add_to_history(self, session_id: str, answer: Dict[str, Any]):
""" Add the answer to the session """

@abstractmethod
def reset(self, session_id):
pass
def reset(self, session_id: str):
""" Reset the session """
3 changes: 2 additions & 1 deletion knowledge_base_gpt/libs/history/redis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Manage history in Redis """
from typing import List, Dict, Any

from injector import inject, singleton
Expand All @@ -10,7 +11,7 @@

@singleton
class HistoryRedis(HistoryBase):

""" Manage history in Redis """
@inject
def __init__(self, settings: Settings):
redis_settings = settings.redis
Expand Down
13 changes: 4 additions & 9 deletions knowledge_base_gpt/libs/injector/di.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
""" Global injector for the application """
from injector import Injector

from knowledge_base_gpt.libs.settings.settings import Settings, unsafe_typed_settings


def create_application_injector() -> Injector:
def _create_application_injector() -> Injector:
_injector = Injector(auto_bind=True)
_injector.binder.bind(Settings, to=unsafe_typed_settings)
return _injector


"""
Global injector for the application.
Avoid using this reference, it will make your code harder to test.
Instead, use the `request.state.injector` reference, which is bound to every request
"""
global_injector: Injector = create_application_injector()
# Global injector for the application.
global_injector: Injector = _create_application_injector()
7 changes: 5 additions & 2 deletions knowledge_base_gpt/libs/loaders/google_drive_loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Content loader from Google Drive """
from typing import List

from injector import inject, singleton
Expand All @@ -8,14 +9,16 @@


@singleton
class GDriveLoader():
class GDriveLoader(): # pylint:disable=R0903
""" Content loader from Google Drive """

@inject
def __init__(self, settings: Settings) -> None:
self._service_key_file = settings.google_drive.service_key_file
self._folder_id = settings.google_drive.folder_id

def load_documents(self, ignored_files: List[str] = []) -> List[Document]:
def load_documents(self, ignored_files: List[str]) -> List[Document]:
""" Load the documents based on the settings and the ignore list """
if not self._folder_id:
return []

Expand Down
7 changes: 5 additions & 2 deletions knowledge_base_gpt/libs/loaders/loaders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Abstract all content loaders """
from typing import List

from injector import inject, singleton
Expand All @@ -9,7 +10,8 @@


@singleton
class Loader():
class Loader(): # pylint:disable=R0903
""" Abstract all content loaders """

@inject
def __init__(self, settings: Settings) -> None:
Expand All @@ -22,5 +24,6 @@ def __init__(self, settings: Settings) -> None:
case _:
pass

def load_documents(self, ignored_files: List[str] = []) -> List[Document]:
def load_documents(self, ignored_files: List[str]) -> List[Document]:
""" Load all the documents based on the settings and the ignore list """
return self._content_loader.load_documents(ignored_files=ignored_files)
Loading

0 comments on commit ccf2a41

Please sign in to comment.