From ed62984cb24e32ef209780c48ab1103a06bc66e7 Mon Sep 17 00:00:00 2001 From: Simon Dai Date: Thu, 19 Oct 2023 00:49:30 -0700 Subject: [PATCH] update Weaviate to support multi tenancy (#11842) - **Description:** update Weaviate to support multi tenancy - **Issue:** 9956 - **Dependencies:** - **Tag maintainer:** hwchase17 - **Twitter handle:** dsx1986_ --- libs/langchain/langchain/vectorstores/weaviate.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/libs/langchain/langchain/vectorstores/weaviate.py b/libs/langchain/langchain/vectorstores/weaviate.py index 85162cd8b6942..83973bb1d9499 100644 --- a/libs/langchain/langchain/vectorstores/weaviate.py +++ b/libs/langchain/langchain/vectorstores/weaviate.py @@ -209,6 +209,8 @@ def similarity_search_by_text( query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("tenant"): + query_obj = query_obj.with_tenant(kwargs.get("tenant")) if kwargs.get("additional"): query_obj = query_obj.with_additional(kwargs.get("additional")) result = query_obj.with_near_text(content).with_limit(k).do() @@ -228,6 +230,8 @@ def similarity_search_by_vector( query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("tenant"): + query_obj = query_obj.with_tenant(kwargs.get("tenant")) if kwargs.get("additional"): query_obj = query_obj.with_additional(kwargs.get("additional")) result = query_obj.with_near_vector(vector).with_limit(k).do() @@ -304,6 +308,8 @@ def max_marginal_relevance_search_by_vector( query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("tenant"): + query_obj = query_obj.with_tenant(kwargs.get("tenant")) results = ( query_obj.with_additional("vector") .with_near_vector(vector) @@ -343,6 +349,8 @@ def similarity_search_with_score( query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("tenant"): + query_obj = query_obj.with_tenant(kwargs.get("tenant")) embedded_query = self._embedding.embed_query(query) if not self._by_text: