Skip to content

Commit

Permalink
faiss_embedding now supports caching
Browse files Browse the repository at this point in the history
  • Loading branch information
george1459 committed Apr 30, 2024
1 parent ee6078a commit 0c2212a
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions src/suql/faiss_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
import hashlib
from collections import OrderedDict

import os
import faiss
import hashlib
import pickle
from FlagEmbedding import FlagModel
from flask import Flask, request
from tqdm import tqdm
from platformdirs import user_cache_dir

from suql.postgresql_connection import execute_sql
from suql.utils import chunk_text
Expand Down Expand Up @@ -48,6 +52,13 @@ def compute_sha256(text):
return hashlib.sha256(text.encode()).hexdigest()


def consistent_tuple_hash(tuple_input):
# Serialize the tuple to bytes using pickle
tuple_bytes = pickle.dumps(tuple_input, protocol=pickle.HIGHEST_PROTOCOL)
# Return the SHA-256 hash of the serialized bytes
return hashlib.sha256(tuple_bytes).hexdigest()


# A set that also preserves insertion order
class OrderedSet:
def __init__(self, iterable=None):
Expand Down Expand Up @@ -126,6 +137,8 @@ def __init__(
user="select_user",
password="select_user",
chunking_param=0,
cache_embedding=True,
force_recompute=False
) -> None:
# stores three lists:
# 1. PSQL primary key for each row
Expand All @@ -147,10 +160,17 @@ def __init__(
# stores PSQL login credentails
self.user = user
self.password = password

# store caching flag
assert type(cache_embedding) == bool
self.cache_embedding = cache_embedding
assert type(force_recompute) == bool
self.force_recompute = force_recompute

self.initialize_from_sql(
table_name, primary_key_field_name, free_text_field_name, db_name
)
print(f"initializing embeddings for DB: {db_name}; TABLE: {table_name}; FREE_TEXT_FIELD: {free_text_field_name}")
self.initialize_embedding()

def initialize_from_sql(
Expand Down Expand Up @@ -217,10 +237,36 @@ def initialize_from_sql(
else:
raise ValueError("Expecting type Str")

def compute_hash(self):
# Convert lists to tuples for hashing
psql_row_ids_tuple = tuple(self.psql_row_ids)
all_free_text_tuple = tuple(self.all_free_text)

# Create a combined tuple of all objects
combined_data = (psql_row_ids_tuple, all_free_text_tuple, self.chunking_param)

# Compute and return the hash of the combined tuple
return consistent_tuple_hash(combined_data)

def initialize_embedding(self):
print("initializing embeddings for all documents")
hash = self.compute_hash()
_user_cache_dir = user_cache_dir('suql')
faiss_cache_location = os.path.join(_user_cache_dir, f'{hash}.faiss_index')
if (os.path.exists(faiss_cache_location) and not self.force_recompute):
try:
print(f"initializing from existing faiss embedding index at {faiss_cache_location}")
self.embeddings = faiss.read_index(faiss_cache_location)
return
except Exception:
print(f"reading {faiss_cache_location} failed. Re-computing embeddings")

self.embeddings = faiss.IndexFlatIP(EMBEDDING_DIMENSION)
self.embeddings.add(embed_documents(self.chunked_text))
indexs = embed_documents(self.chunked_text)
self.embeddings.add(indexs)

print(f"writing computed faiss embedding to {faiss_cache_location}")
os.makedirs(_user_cache_dir, exist_ok=True)
faiss.write_index(self.embeddings, faiss_cache_location)

def dot_product(self, id_list, query, top, individual_id_list=[]):
# given a list of id and a particular query, return the top ids and documents according to similarity score ranking
Expand Down Expand Up @@ -367,6 +413,8 @@ def add(
user="select_user",
password="select_user",
chunking_param=0,
cache_embedding=True,
force_recompute=False
):
if (
table_name in self.mapping
Expand All @@ -388,6 +436,8 @@ def add(
user=user,
password=password,
chunking_param=chunking_param,
cache_embedding=cache_embedding,
force_recompute=force_recompute
)

def retrieve(self, table_name, free_text_field_name):
Expand Down

0 comments on commit 0c2212a

Please sign in to comment.