Create Embeddings and VectorStore #750

Open
ahuang11 opened this issue Nov 13, 2024 · 1 comment · May be fixed by #764
Comments

ahuang11 (Contributor) commented Nov 13, 2024

I am planning to refactor the existing Embeddings class.

The purpose is to supply LLMs with up-to-date or private data via retrieval-augmented generation (RAG): LLMs are trained on static, sometimes outdated datasets, so on their own they may not provide accurate information.

Goals:

  • Create text embeddings to perform similarity searches.
    When a user makes a query, the system compares embeddings to find the most relevant texts, so responses stay grounded in the right context.

  • Allow users to plug in their own embedding models.
    Provide an interface (abstract base classes) so methods stay consistent regardless of which embedding model is used (e.g., OpenAI, HuggingFace, MistralAI, WordLlama).

  • Cache embeddings to enable fast lookups.
    Use vector stores (e.g., NumPy arrays for in-memory storage, or DuckDB for persistent storage) to hold the embeddings, speeding up similarity searches and improving overall performance (see the small sketch below this list).
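
For context, the similarity search in these goals boils down to comparing a query embedding against the cached embeddings; a toy NumPy sketch (not part of the proposed API):

import numpy as np

def cosine_similarity(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    # Cosine similarity between one query vector and each row of a matrix
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
    return (matrix @ query_vec) / norms

# Toy data: three cached "text" embeddings and one query embedding
embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
query = np.array([1.0, 0.1])
top_k = 2
best = np.argsort(cosine_similarity(query, embeddings))[::-1][:top_k]
print(best)  # indices of the most similar cached texts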


I propose the following interfaces with minimal methods, to keep the interface from becoming too rigid. I did not add delete to the interface because some of the stores, like the NumPy one, are not persistent; the DuckDB store will have it, though.

from abc import ABC, abstractmethod
from typing import Dict, List, Optional


class Embeddings(ABC):
    @abstractmethod
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of texts."""
        pass


class VectorStore(ABC):
    def __init__(self, embedding_model: 'Embeddings'):
        self.embedding_model = embedding_model

    @abstractmethod
    def add(self, texts: List[str], metadata: Optional[List[Dict]] = None) -> List[int]:
        """
        Add texts and their metadata to the store.
        Returns:
            List[int]: A list of unique text IDs for the added texts.
        """
        pass

    @abstractmethod
    def query(self, text: str, top_k: int = 5) -> List[Dict]:
        """
        Query store for similar texts.
        Returns:
            List[Dict]: List of matching texts with metadata and similarity scores.
        """
        pass
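
Regardless of which backend is plugged in, the intended call pattern would look roughly like this (class names and result keys are illustrative only, mirroring the examples further down):

embeddings = OpenAIEmbeddings(api_key="...")           # any Embeddings subclass
store = NumpyVectorStore(embedding_model=embeddings)   # any VectorStore subclass

store.add(["some document text"], metadata=[{"source": "docs"}])
for match in store.query("a user question", top_k=3):
    print(match["text"], match["similarity"])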

Then, I am planning to implement these:

OpenAI Embeddings
MistralAI Embeddings
HuggingFace Embeddings
WordLlama Embeddings

Example:

class OpenAIEmbeddings(Embeddings):
    def __init__(self, api_key: str, model: str = 'text-embedding-3-small'):
        from openai import OpenAI
        self.client = OpenAI(api_key=api_key)
        self.model = model

    def embed(self, texts: List[str]) -> List[List[float]]:
        texts = [text.replace("\n", " ") for text in texts]
        response = self.client.embeddings.create(input=texts, model=self.model)
        return [r.embedding for r in response.data]
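
A HuggingFace variant could follow the same shape; a minimal sketch assuming sentence-transformers as the backing library:

class HuggingFaceEmbeddings(Embeddings):
    def __init__(self, model: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model)

    def embed(self, texts: List[str]) -> List[List[float]]:
        # encode() returns a NumPy array of shape (len(texts), embedding_dim)
        return self.model.encode(texts).tolist()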

With these stores:

NumpyVectorStore (in-memory)
DuckDBVectorStore (persistent)
WordLlamaVectorStore (in-memory)
ChromaVectorStore (in-memory or persistent)

Example:

import duckdb
import json
from typing import List, Optional, Dict

class DuckDBVectorStore(VectorStore):
    def __init__(self, embedding_model: 'Embeddings', db_path: str = ':memory:'):
        super().__init__(embedding_model)
        self.connection = duckdb.connect(database=db_path)
        self._setup_database()

    def _setup_database(self) -> None:
        # DuckDB has no AUTO_INCREMENT; use a sequence to generate IDs
        self.connection.execute("CREATE SEQUENCE IF NOT EXISTS documents_id_seq;")
        self.connection.execute("""
            CREATE TABLE IF NOT EXISTS documents (
                id BIGINT PRIMARY KEY DEFAULT nextval('documents_id_seq'),
                text VARCHAR,
                embedding FLOAT[],
                table_name VARCHAR,
                metadata JSON
            );
        """)
        # The HNSW index comes from DuckDB's vss extension and requires a
        # fixed-size FLOAT[n] column, so treat it as an optional optimization
        try:
            self.connection.execute("INSTALL vss;")
            self.connection.execute("LOAD vss;")
            self.connection.execute("""
                CREATE INDEX IF NOT EXISTS embedding_index
                ON documents USING HNSW (embedding) WITH (metric = 'cosine');
            """)
        except duckdb.Error:
            pass

    def add(self, texts: List[str], metadata: Optional[List[Dict]] = None, table_name: str = "default") -> List[int]:
        embeddings = self.embedding_model.embed(texts)
        text_ids = []
        for i, (text, embedding) in enumerate(zip(texts, embeddings)):
            meta = metadata[i] if metadata else {}
            result = self.connection.execute("""
                INSERT INTO documents (text, embedding, table_name, metadata)
                VALUES (?, ?, ?, ?) RETURNING id;
            """, [text, embedding, table_name, json.dumps(meta)])
            text_ids.append(result.fetchone()[0])  # Fetch and collect the generated IDs
        return text_ids

    def delete(self, text_ids: List[int]) -> None:
        if not text_ids:
            return
        # One placeholder per ID; a tuple bound to "IN ?" is not valid here
        placeholders = ", ".join(["?"] * len(text_ids))
        self.connection.execute(f"DELETE FROM documents WHERE id IN ({placeholders});", text_ids)

    def query(self, text: str, top_k: int = 5, table_name: Optional[str] = None) -> List[Dict]:
        query_embedding = self.embedding_model.embed([text])[0]
        # list_cosine_similarity is DuckDB's cosine similarity over LIST columns;
        # higher means more similar, so order descending
        sql = """
            SELECT id, text, metadata,
                   list_cosine_similarity(embedding, ?) AS similarity
            FROM documents
        """
        params: List = [query_embedding]
        if table_name:
            sql += " WHERE table_name = ?"
            params.append(table_name)
        sql += " ORDER BY similarity DESC LIMIT ?;"
        params.append(top_k)
        result = self.connection.execute(sql, params).fetchall()

        return [
            {"id": row[0], "text": row[1], "metadata": json.loads(row[2]), "similarity": row[3]}
            for row in result
        ]

    def lookup_text_ids(self, texts: Optional[List[str]] = None, metadata: Optional[Dict] = None) -> List[int]:
        query = "SELECT id FROM documents WHERE 1=1"
        params: List = []

        if texts:
            placeholders = ", ".join(["?"] * len(texts))
            query += f" AND text IN ({placeholders})"
            params.extend(texts)

        if metadata:
            # Match each metadata key/value via DuckDB's JSON extension
            for key, value in metadata.items():
                query += " AND json_extract_string(metadata, ?) = ?"
                params.extend([f"$.{key}", str(value)])

        result = self.connection.execute(query, params).fetchall()
        return [row[0] for row in result]
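
For comparison, the in-memory NumPy store could be as small as this (a sketch, not a final implementation):

import numpy as np

class NumpyVectorStore(VectorStore):
    """In-memory store backed by a plain NumPy array (nothing is persisted)."""

    def __init__(self, embedding_model: 'Embeddings'):
        super().__init__(embedding_model)
        self._vectors: Optional[np.ndarray] = None
        self._texts: List[str] = []
        self._metadata: List[Dict] = []

    def add(self, texts: List[str], metadata: Optional[List[Dict]] = None) -> List[int]:
        vectors = np.asarray(self.embedding_model.embed(texts), dtype="float32")
        start = len(self._texts)
        self._vectors = vectors if self._vectors is None else np.vstack([self._vectors, vectors])
        self._texts.extend(texts)
        self._metadata.extend(metadata or [{} for _ in texts])
        return list(range(start, start + len(texts)))

    def query(self, text: str, top_k: int = 5) -> List[Dict]:
        if self._vectors is None:
            return []
        query_vec = np.asarray(self.embedding_model.embed([text])[0], dtype="float32")
        norms = np.linalg.norm(self._vectors, axis=1) * np.linalg.norm(query_vec)
        similarities = (self._vectors @ query_vec) / norms
        top = np.argsort(similarities)[::-1][:top_k]
        return [
            {"id": int(i), "text": self._texts[i], "metadata": self._metadata[i],
             "similarity": float(similarities[i])}
            for i in top
        ]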
philippjfr (Member) commented:

def add(self, texts: List[str], metadata: Optional[List[Dict]] = None, table_name: str = "default") -> List[int]:

Maybe combine text and metadata into one list and then let's just shove the table_name into the metadata for now.
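
i.e., something roughly along these lines (just sketching the shape of that suggestion, not committing to names):

def add(self, items: List[Dict]) -> List[int]:
    """Each item carries its own text and metadata, e.g.
    {"text": "...", "metadata": {"table_name": "default", ...}}."""
    ...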

@ahuang11 ahuang11 linked a pull request Nov 18, 2024 that will close this issue