diff --git a/documentation/docs/miscellaneous/advanced.md b/documentation/docs/miscellaneous/advanced.md
index b2023c1b2..63682d495 100644
--- a/documentation/docs/miscellaneous/advanced.md
+++ b/documentation/docs/miscellaneous/advanced.md
@@ -7,9 +7,11 @@ sidebar_position: 3
 
 ## Search across Different Languages (Self-Hosting)
 To search for notes in multiple languages, you can use a [multi-lingual model](https://www.sbert.net/docs/pretrained_models.html#multi-lingual-models).
 For example, the [paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) supports [50+ languages](https://www.sbert.net/docs/pretrained_models.html#:~:text=we%20used%20the%20following%2050%2B%20languages) with good search quality and speed. To use it:
-1. Manually update the search config in server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/multi-qa-MiniLM-L6-cos-v1` and the cross_encoder to `cross-encoder/ms-marco-MiniLM-L-6-v2`.
+1. Manually update the search config on the server's admin settings page. Go to [the search config](http://localhost:42110/server/admin/database/searchmodelconfig/). Either create a new one, if none exists, or update the existing one. Set the bi_encoder to `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` and the cross_encoder to `mixedbread-ai/mxbai-rerank-xsmall-v1`.
 2. Regenerate your content index from all the relevant clients. This step is very important, as you'll need to re-encode all your content with the new model.
+Note: Some search models (e.g. [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)) expect a prefix to be added to the query (or docs) string before encoding. To set one, update the `bi_encoder_query_encode_config` field with `{"prompt": "<prefix>"}`. E.g. `{"prompt": "Represent this query for searching documents"}`. You can pass any valid JSON object of keyword arguments that the SentenceTransformer `encode` function accepts.
+
 
 ## Query Filters
 Use structured query syntax to filter the entries from your knowledge base used for search results or chat responses.
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 419bf9501..38b8223f2 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -216,6 +216,9 @@ def configure_server(
                     model.bi_encoder,
                     model.embeddings_inference_endpoint,
                     model.embeddings_inference_endpoint_api_key,
+                    query_encode_kwargs=model.bi_encoder_query_encode_config,
+                    docs_encode_kwargs=model.bi_encoder_docs_encode_config,
+                    model_kwargs=model.bi_encoder_model_config,
                 )
             }
         )
diff --git a/src/khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py b/src/khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py
new file mode 100644
index 000000000..fc33e12b0
--- /dev/null
+++ b/src/khoj/database/migrations/0037_searchmodelconfig_bi_encoder_docs_encode_config_and_more.py
@@ -0,0 +1,32 @@
+# Generated by Django 4.2.10 on 2024-04-24 04:19
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0036_delete_offlinechatprocessorconversationconfig"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="bi_encoder_docs_encode_config",
+            field=models.JSONField(default=dict),
+        ),
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="bi_encoder_model_config",
+            field=models.JSONField(default=dict),
+        ),
+        migrations.AddField(
+            model_name="searchmodelconfig",
+            name="bi_encoder_query_encode_config",
+            field=models.JSONField(default=dict),
+        ),
+        migrations.AlterField(
+            model_name="searchmodelconfig",
+            name="cross_encoder",
+            field=models.CharField(default="mixedbread-ai/mxbai-rerank-xsmall-v1", max_length=200),
+        ),
+    ]
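Reviewer note: to make the new fields concrete, here is a minimal sketch of how a self-hosted admin could populate them once this migration has run, e.g. from the Django shell. The bi-encoder name, device, and prompt values are illustrative assumptions, not defaults shipped with this change.

```python
# Hypothetical configuration of the new SearchModelConfig JSON fields via
# `python manage.py shell`. All field values below are illustrative.
from khoj.database.models import SearchModelConfig

SearchModelConfig.objects.update_or_create(
    name="default",
    defaults={
        # A bi-encoder that expects a prefix on queries (see the docs note above)
        "bi_encoder": "mixedbread-ai/mxbai-embed-large-v1",
        # Forwarded to the SentenceTransformer constructor on model load
        "bi_encoder_model_config": {"device": "cuda:0"},
        # Merged into every encode() call for queries
        "bi_encoder_query_encode_config": {"prompt": "Represent this query for searching documents"},
        # Documents are encoded without a prefix in this sketch
        "bi_encoder_docs_encode_config": {},
    },
)
```

After such a change, the content index still needs to be regenerated (step 2 in the docs above), since existing embeddings were computed without the new prefix.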
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 15f396f14..ae13e9803 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -179,13 +179,27 @@ class SearchModelConfig(BaseModel):
     class ModelType(models.TextChoices):
         TEXT = "text"
 
+    # This is the model name exposed to users on their settings page
     name = models.CharField(max_length=200, default="default")
+    # Type of content the model can generate embeddings for
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
+    # Bi-encoder model of sentence-transformer type to load from HuggingFace
     bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
-    cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
+    # Config passed to the sentence-transformer model constructor. E.g. device="cuda:0", trust_remote_code=True, etc.
+    bi_encoder_model_config = models.JSONField(default=dict)
+    # Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
+    bi_encoder_query_encode_config = models.JSONField(default=dict)
+    # Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
+    bi_encoder_docs_encode_config = models.JSONField(default=dict)
+    # Cross-encoder model of sentence-transformer type to load from HuggingFace
+    cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
+    # Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
     embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
+    # Inference server API Key to use for embeddings inference. Bi-encoder model should be hosted on this server
     embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
+    # Inference server API endpoint to use for reranking inference. Cross-encoder model should be hosted on this server
     cross_encoder_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
+    # Inference server API Key to use for reranking inference. Cross-encoder model should be hosted on this server
     cross_encoder_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
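Reviewer note: the new cross-encoder default can be exercised directly. A minimal sketch, assuming CrossEncoderModel loads the model through sentence-transformers' CrossEncoder class (the "of sentence-transformer type" comments above suggest this); the query and passages are illustrative.

```python
# Rerank a few passages with the new default cross-encoder.
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1", device="cpu")

query = "How do I set up multilingual search?"
passages = [
    "Khoj supports multi-lingual bi-encoder models for search.",
    "The weather in Berlin is mild in spring.",
]

# predict() scores each (query, passage) pair; higher means more relevant
scores = reranker.predict([(query, passage) for passage in passages])
print(sorted(zip(scores, passages), reverse=True))
```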
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index ec8e08f01..701bbfac8 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -13,7 +13,7 @@
 )
 from torch import nn
 
-from khoj.utils.helpers import get_device
+from khoj.utils.helpers import get_device, merge_dicts
 from khoj.utils.rawconfig import SearchResponse
 
 logger = logging.getLogger(__name__)
@@ -25,9 +25,15 @@ def __init__(
         model_name: str = "thenlper/gte-small",
         embeddings_inference_endpoint: str = None,
         embeddings_inference_endpoint_api_key: str = None,
+        query_encode_kwargs: dict = {},
+        docs_encode_kwargs: dict = {},
+        model_kwargs: dict = {},
     ):
-        self.encode_kwargs = {"normalize_embeddings": True}
-        self.model_kwargs = {"device": get_device()}
+        default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
+        default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
+        self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
+        self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
+        self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
@@ -39,7 +45,7 @@ def inference_server_enabled(self) -> bool:
     def embed_query(self, query):
         if self.inference_server_enabled():
             return self.embed_with_api([query])[0]
-        return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
+        return self.embeddings_model.encode([query], **self.query_encode_kwargs)[0]
 
     @retry(
         retry=retry_if_exception_type(requests.exceptions.HTTPError),
@@ -70,7 +76,7 @@ def embed_documents(self, docs):
                 logger.warning(
                     f"Unsupported inference endpoint: {self.inference_endpoint}. Only HuggingFace supported. Generating embeddings on device instead."
                 )
-                return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
+                return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
             # break up the docs payload in chunks of 1000 to avoid hitting rate limits
             embeddings = []
             with tqdm.tqdm(total=len(docs)) as pbar:
@@ -80,13 +86,13 @@
                     embeddings += generated_embeddings
                     pbar.update(1000)
             return embeddings
-        return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
+        return self.embeddings_model.encode(docs, **self.docs_encode_kwargs).tolist()
 
 
 class CrossEncoderModel:
     def __init__(
         self,
-        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
+        model_name: str = "mixedbread-ai/mxbai-rerank-xsmall-v1",
         cross_encoder_inference_endpoint: str = None,
         cross_encoder_inference_endpoint_api_key: str = None,
     ):
diff --git a/tests/helpers.py b/tests/helpers.py
index 642f05ddc..686735967 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -75,7 +75,7 @@ class Meta:
     name = "default"
     model_type = "text"
     bi_encoder = "thenlper/gte-small"
-    cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+    cross_encoder = "mixedbread-ai/mxbai-rerank-xsmall-v1"
 
 
 class SubscriptionFactory(factory.django.DjangoModelFactory):
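Reviewer note: the core behavioral change in embeddings.py is that user-supplied encode kwargs now overlay the defaults before being passed to SentenceTransformer.encode. A standalone sketch of that flow, assuming merge_dicts gives its first argument priority on conflicting keys (as its usage in EmbeddingsModel.__init__ implies); the model name and prompt are illustrative.

```python
# Standalone sketch of the new encode-kwargs flow. merge_dicts is
# reimplemented here for illustration under the stated assumption.
from sentence_transformers import SentenceTransformer


def merge_dicts(priority: dict, default: dict) -> dict:
    """Overlay `priority` on top of `default`, returning a new dict."""
    return {**default, **priority}


# User config, e.g. from SearchModelConfig.bi_encoder_query_encode_config
query_encode_kwargs = merge_dicts(
    {"prompt": "Represent this query for searching documents"},
    # Defaults from EmbeddingsModel.__init__; user keys win on conflict
    {"show_progress_bar": False, "normalize_embeddings": True},
)

# bi_encoder_model_config is forwarded to the constructor the same way
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device="cpu")

# encode() accepts `prompt`, which is prepended to each input before embedding
embedding = model.encode(["Where did I go hiking last spring?"], **query_encode_kwargs)[0]
print(embedding.shape)
```

This mirrors embed_query's behavior when bi_encoder_query_encode_config is set; embed_documents follows the same pattern with its own docs kwargs (progress bar on by default).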