Merge pull request #3 from AlessandroSpallina/develop
Welcome to v0.2.0
Showing 5 changed files with 124 additions and 56 deletions.
@@ -1,70 +1,134 @@
-import sqlite3
-import os
+import hashlib
 from typing import List
 from cat.log import log
-from cat.mad_hatter.decorators import tool, hook
-from pydantic import BaseModel
-from langchain.indexes import SQLRecordManager, index
 from langchain.docstore.document import Document
-from langchain.vectorstores import Qdrant
+from cat.mad_hatter.decorators import hook
 
-# TODO: use settings instead of hard coded db path
-# class DietSettings(BaseModel):
-#     sqlite_file_path: str = "/app/cat/plugins/ccat-dietician/diet.db"
+from typing import List
+from typing import Optional
+from sqlalchemy import ForeignKey, MetaData
+from sqlalchemy import String
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class DietDocument(Base):
+    __tablename__ = 'document'
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String(256), unique=True)
+    hash: Mapped[str] = mapped_column(String(64), unique=True)
+
+    chunks: Mapped[List["Chunk"]] = relationship(back_populates="document")
 
-# @hook
-# def plugin_settings_schema():
-#     return DietSettings.schema()
+    def __repr__(self) -> str:
+        return f'DietDocument(name={self.name!r}, hash={self.hash!r})'
 
 
-# Hook called when a list of Document is going to be inserted in memory from the rabbit hole.
-# Here you can edit/summarize the documents before inserting them in memory
-# Should return a list of documents (each is a langchain Document)
-@hook
-def before_rabbithole_stores_documents(docs: List[Document], cat) -> List[Document]:
-    """Hook into the memory insertion pipeline.
+class Chunk(Base):
+    __tablename__ = 'chunk'
+    id: Mapped[int] = mapped_column(primary_key=True)
+    chunk_count: Mapped[int]
+    document_id: Mapped[int] = mapped_column(ForeignKey("document.id"))
 
-    Allows modifying how the list of `Document` is inserted in the vector memory.
+    document: Mapped["DietDocument"] = relationship(back_populates="chunks")
 
-    For example, this hook is a good point to summarize the incoming documents and save both original and
-    summarized contents.
-    An official plugin is available to test this procedure.
+    def __repr__(self) -> str:
+        return f'Chunk(chunk_count={self.chunk_count!r})'
 
-    Parameters
-    ----------
-    docs : List[Document]
-        List of Langchain `Document` to be edited.
-    cat: CheshireCat
-        Cheshire Cat instance.
-
-    Returns
-    -------
-    docs : List[Document]
-        List of edited Langchain documents.
+engine = create_engine("sqlite:///cat/plugins/ccat-dietician/dietician.db")
 
-    """
+Base.metadata.create_all(engine, checkfirst=True)
 
-    vector_db = cat.memory.vectors.vector_db
-    embedder = cat.embedder
-
-    q = Qdrant(vector_db, "declarative", embedder)
-
-    record_manager = SQLRecordManager(
-        namespace="qdrant/declarative",
-        db_url="sqlite:////app/cat/plugins/ccat-dietician/diet.db"
-    )
+@hook(priority=10)
+def before_rabbithole_splits_text(doc, cat):
+    # doc is a list with only one element, always
+    cat.working_memory['ccat-dietician'] = {
+        'name': doc[0].metadata['source'],
+        'hash': hashlib.sha256(doc[0].page_content.encode()).hexdigest()
+    }
 
-    record_manager.create_schema()
+    return doc
 
-    ret = index(
-        docs,
-        record_manager,
-        q,
-        delete_mode="incremental",
-        source_id_key="source"
-    )
-
-    log(f"Dietist: index return is {ret}", "DEBUG")
+
+# Hook called when a list of Document is going to be inserted in memory from the rabbit hole.
+# Here you can edit/summarize the documents before inserting them in memory
+# Should return a list of documents (each is a langchain Document)
+@hook(priority=10)
+def before_rabbithole_stores_documents(docs: List[Document], cat) -> List[Document]:
+    cat.working_memory['ccat-dietician']['chunk_count'] = len(docs)
+
+    # document = session.query(DietDocument).filter_by(hash=hash).first()
+
+    with Session(engine) as session:
+        try:
+            doc_by_name = session.query(DietDocument).filter_by(name=cat.working_memory['ccat-dietician']['name']).first()
+            if doc_by_name is None:
+                doc_by_hash = session.query(DietDocument).filter_by(hash=cat.working_memory['ccat-dietician']['hash']).first()
+                if doc_by_hash is None:
+                    # never seen this name or this content before: record it and ingest everything
+                    db_doc = DietDocument(name=cat.working_memory['ccat-dietician']['name'], hash=cat.working_memory['ccat-dietician']['hash'], chunks=[Chunk(chunk_count=cat.working_memory['ccat-dietician']['chunk_count'])])
+                    session.add(db_doc)
+                    session.commit()
+                    log.info(f"Dietician is allowing the ingestion of a new document: {db_doc}")
+                    return docs
+                else:
+                    if cat.working_memory['ccat-dietician']['chunk_count'] in [c.chunk_count for c in doc_by_hash.chunks]:
+                        log.info(f"Dietician detected {cat.working_memory['ccat-dietician']['name']} as a duplicate of {doc_by_hash.name}: the number of chunks ({cat.working_memory['ccat-dietician']['chunk_count']}) coincides with what is already in declarative memory, so this ingestion is going to be skipped.")
+                        return []
+                    else:
+                        doc_by_hash.chunks.append(Chunk(chunk_count=cat.working_memory['ccat-dietician']['chunk_count']))
+                        session.add(doc_by_hash)
+                        session.commit()
+                        log.info(f"Dietician detected {cat.working_memory['ccat-dietician']['name']} as a duplicate of {doc_by_hash.name}, but the number of chunks ({cat.working_memory['ccat-dietician']['chunk_count']}) produced now differs from what is already in declarative memory, so this ingestion is going to be allowed.")
+                        return docs
+            else:
+                if cat.working_memory['ccat-dietician']['hash'] == doc_by_name.hash:
+                    if cat.working_memory['ccat-dietician']['chunk_count'] in [c.chunk_count for c in doc_by_name.chunks]:
+                        log.info(f"Dietician detected that {doc_by_name.name} was already ingested: the number of chunks ({cat.working_memory['ccat-dietician']['chunk_count']}) coincides with what is already in declarative memory, so this ingestion is going to be skipped.")
+                        return []
+                    else:
+                        doc_by_name.chunks.append(Chunk(chunk_count=cat.working_memory['ccat-dietician']['chunk_count']))
+                        session.add(doc_by_name)
+                        session.commit()
+                        log.info(f"Dietician detected that {doc_by_name.name} was already ingested, but the number of chunks ({cat.working_memory['ccat-dietician']['chunk_count']}) produced now differs from what is already in declarative memory, so this ingestion is going to be allowed.")
+                        return docs
+                else:
+                    old_chunks, _ = cat.memory.vectors.declarative.client.scroll(
+                        collection_name=cat.memory.vectors.declarative.collection_name,
+                        scroll_filter=cat.memory.vectors.declarative._qdrant_filter_from_dict({'source': doc_by_name.name}),
+                        with_payload=True
+                    )
+                    old_chunks_text = [c.payload['page_content'] for c in old_chunks]
+                    new_chunks_text = [d.page_content for d in docs]
+
+                    # delete every chunk in declarative memory that is not in the new document: those chunks belong to an old version of the document
+                    old_chunks_to_delete_ids = [c.id for c in old_chunks if c.payload['page_content'] not in new_chunks_text]
+
+                    if len(old_chunks_to_delete_ids) > 0:
+                        cat.memory.vectors.declarative.delete_points(old_chunks_to_delete_ids)
+
+                    log.info(f"Dietician detected a hash change for the document {doc_by_name}: the document has been updated. Allowing the ingestion of the new chunks and deleting the old chunks no longer present in the current version.")
+
+                    # return only the chunks not yet in declarative memory; chunks already stored are kept in the vector db (avoiding unnecessary embedding calls)
+                    return [d for d in docs if d.page_content not in old_chunks_text]
+
+        except Exception as e:
+            session.rollback()
+            log.error(f"Something weird happened: {str(e)}. Dietician is preventing the ingestion of {cat.working_memory['ccat-dietician']['name']}")
+            return []
 
     return []
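Review note: the branching in `before_rabbithole_stores_documents` is easier to audit when read as a pure decision function. Below is a minimal sketch of that decision table; `KnownDocument` and `decide` are hypothetical names invented for this note, not part of the plugin or the Cat API.

```python
# Illustrative decision table for before_rabbithole_stores_documents.
# KnownDocument and decide are hypothetical names used only in this sketch.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class KnownDocument:
    name: str
    hash: str
    chunk_counts: List[int] = field(default_factory=list)  # splittings seen so far


def decide(content_hash: str, chunk_count: int,
           by_name: Optional[KnownDocument],
           by_hash: Optional[KnownDocument]) -> str:
    if by_name is None and by_hash is None:
        return "ingest-new"          # never seen before: store everything
    if by_name is not None and content_hash != by_name.hash:
        return "ingest-updated"      # same name, new content: reconcile chunks
    known = by_name or by_hash
    if chunk_count in known.chunk_counts:
        return "skip-duplicate"      # same content, same splitting: do nothing
    return "ingest-resplit"          # same content, new splitter settings


assert decide("abc", 12, None, None) == "ingest-new"
assert decide("abc", 12, None, KnownDocument("diet.pdf", "abc", [12])) == "skip-duplicate"
assert decide("xyz", 12, KnownDocument("diet.pdf", "abc", [12]), None) == "ingest-updated"
```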
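The updated-document branch boils down to a set difference on chunk text: stale chunks are deleted from Qdrant, and only genuinely new chunks are re-embedded. Here is the same idea sketched on plain strings (hypothetical sample data, no Qdrant involved):

```python
# Chunk reconciliation for an updated document, sketched on plain strings.
# The plugin does the same against Qdrant payloads fetched via client.scroll().
old_chunks = {"intro", "breakfast: oats", "dinner: soup"}   # already embedded
new_chunks = {"intro", "breakfast: eggs", "dinner: soup"}   # current document version

to_delete = old_chunks - new_chunks  # stale points to drop from the vector store
to_embed = new_chunks - old_chunks   # the only chunks that need new embedding calls

assert to_delete == {"breakfast: oats"}
assert to_embed == {"breakfast: eggs"}
```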
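Finally, rename detection works because `before_rabbithole_splits_text` hashes the raw text before any splitting, so the same content uploaded under two different file names produces the same digest; a SHA-256 hex digest is 64 characters, which is presumably why the `hash` column above is `String(64)`. A quick check:

```python
import hashlib

# Two uploads with identical content but different file names collide on hash,
# which is what lets the plugin flag the second upload as a duplicate.
digest_a = hashlib.sha256("same diet plan".encode()).hexdigest()
digest_b = hashlib.sha256("same diet plan".encode()).hexdigest()

assert digest_a == digest_b
assert len(digest_a) == 64  # a SHA-256 hex digest fits the String(64) column
```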