This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add support for namespaces #243

Merged
26 commits merged on Jan 15, 2024
Changes from 9 commits
9 changes: 6 additions & 3 deletions src/canopy/chat_engine/chat_engine.py
@@ -165,7 +165,8 @@ def chat(self,
messages: Messages,
*,
stream: bool = False,
model_params: Optional[dict] = None
model_params: Optional[dict] = None,
namespace: Optional[str] = None
) -> Union[ChatResponse, StreamingChatResponse]:
"""
Chat completion with RAG. Given a list of messages (history), the chat engine will generate the next response, based on the relevant context retrieved from the knowledge base.
@@ -180,6 +181,7 @@ def chat(self,
messages: A list of messages (history) to generate the next response from.
stream: A boolean flag to indicate if the chat should be streamed or not. Defaults to False.
model_params: A dictionary of model parameters to use for the LLM. Defaults to None, which means the LLM will use its default values.
namespace: The namespace of the index for context retrieval. To learn more about namespaces, see https://docs.pinecone.io/docs/namespaces

Returns:
A ChatResponse object if stream is False, or a StreamingChatResponse object if stream is True.
@@ -196,7 +198,7 @@ def chat(self,
>>> for chunk in response.chunks:
... print(chunk.json())
""" # noqa: E501
context = self._get_context(messages)
context = self._get_context(messages, namespace)
llm_messages = self._history_pruner.build(
chat_history=messages,
max_tokens=self.max_prompt_tokens,
@@ -227,9 +229,10 @@ def chat(self,

def _get_context(self,
messages: Messages,
namespace: Optional[str] = None
) -> Context:
queries = self._query_builder.generate(messages, self.max_prompt_tokens)
context = self.context_engine.query(queries, self.max_context_tokens)
context = self.context_engine.query(queries, self.max_context_tokens, namespace)
return context

async def achat(self,
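Taken together, these changes thread an optional `namespace` argument from `ChatEngine.chat()` down to context retrieval. A minimal usage sketch, assuming a Canopy index already exists and the relevant API keys are set; the object construction and `UserMessage` import follow Canopy's existing public API and are not part of this diff:

```python
from canopy.knowledge_base import KnowledgeBase
from canopy.context_engine import ContextEngine
from canopy.chat_engine import ChatEngine
from canopy.models.data_models import UserMessage

# Assumes PINECONE_API_KEY / OPENAI_API_KEY are set and a Canopy index exists.
kb = KnowledgeBase(index_name="my-index")
kb.connect()
chat_engine = ChatEngine(ContextEngine(kb))

# New in this PR: restrict context retrieval to a single Pinecone namespace.
response = chat_engine.chat(
    messages=[UserMessage(content="What is a namespace?")],
    namespace="customer-a",  # omit or pass None to use the default namespace
)
print(response.choices[0].message.content)
```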
8 changes: 6 additions & 2 deletions src/canopy/context_engine/context_engine.py
@@ -81,13 +81,16 @@ def __init__(self,

self.global_metadata_filter = global_metadata_filter

def query(self, queries: List[Query], max_context_tokens: int, ) -> Context:
def query(self, queries: List[Query],
max_context_tokens: int,
namespace: Optional[str] = None) -> Context:
"""
Query the knowledge base for relevant documents and build a context from the retrieved documents that can be injected into the LLM prompt.

Args:
queries: A list of queries to use for retrieving documents from the knowledge base
max_context_tokens: The maximum number of tokens to use for the context
namespace: The namespace of the index for context retrieval. To learn more about namespaces, see https://docs.pinecone.io/docs/namespaces

Returns:
A Context object containing the retrieved documents and metadata
@@ -100,7 +103,8 @@ def query(self, queries: List[Query], max_context_tokens: int, ) -> Context:
""" # noqa: E501
query_results = self.knowledge_base.query(
queries,
global_metadata_filter=self.global_metadata_filter)
global_metadata_filter=self.global_metadata_filter,
namespace=namespace)
context = self.context_builder.build(query_results, max_context_tokens)

if CE_DEBUG_INFO:
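For callers that use the `ContextEngine` directly, the namespace is now an optional third argument to `query()`. A brief sketch, assuming `context_engine` is constructed as in the example above and that `Context.to_text()` behaves as in the existing API:

```python
from canopy.models.data_models import Query

queries = [Query(text="how do namespaces work?")]

# namespace is forwarded to KnowledgeBase.query(); None preserves the old behavior.
context = context_engine.query(
    queries,
    max_context_tokens=512,
    namespace="customer-a",
)
print(context.to_text())
```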
3 changes: 2 additions & 1 deletion src/canopy/knowledge_base/base.py
@@ -14,7 +14,8 @@ class BaseKnowledgeBase(ABC, ConfigurableMixin):
@abstractmethod
def query(self,
queries: List[Query],
global_metadata_filter: Optional[dict] = None
global_metadata_filter: Optional[dict] = None,
namespace: Optional[str] = None
) -> List[QueryResult]:
pass

16 changes: 10 additions & 6 deletions src/canopy/knowledge_base/knowledge_base.py
@@ -25,7 +25,6 @@
from canopy.knowledge_base.reranker import Reranker, TransparentReranker
from canopy.models.data_models import Query, Document


INDEX_NAME_PREFIX = "canopy--"
TIMEOUT_INDEX_CREATE = 300
TIMEOUT_INDEX_PROVISION = 30
@@ -69,7 +68,6 @@ def list_canopy_indexes() -> List[str]:


class KnowledgeBase(BaseKnowledgeBase):

"""
The `KnowledgeBase` is used to store and retrieve text documents, using an underlying Pinecone index.
Every document is chunked into multiple text snippets based on the text structure (e.g. Markdown or HTML formatting)
@@ -401,7 +399,8 @@ def delete_index(self):

def query(self,
queries: List[Query],
global_metadata_filter: Optional[dict] = None
global_metadata_filter: Optional[dict] = None,
namespace: Optional[str] = None
) -> List[QueryResult]:
"""
Query the knowledge base to retrieve document chunks.
@@ -417,6 +416,8 @@ def query(self,
global_metadata_filter: A metadata filter to apply to all queries, in addition to any query-specific filters.
For example, the filter {"website": "wiki"} will only return documents with the metadata {"website": "wiki"} (in case provided in upsert)
see https://docs.pinecone.io/docs/metadata-filtering
namespace: The namespace of the index for context retrieval. To learn more about namespaces, see https://docs.pinecone.io/docs/namespaces

Returns:
A list of QueryResult objects.

@@ -436,7 +437,9 @@ def query(self,
raise RuntimeError(self._connection_error_msg)

queries = self._encoder.encode_queries(queries)
results = [self._query_index(q, global_metadata_filter) for q in queries]
results = [self._query_index(q,
global_metadata_filter,
namespace) for q in queries]
results = self._reranker.rerank(results)

return [
@@ -455,7 +458,8 @@ def query(self,

def _query_index(self,
query: KBQuery,
global_metadata_filter: Optional[dict]) -> KBQueryResult:
global_metadata_filter: Optional[dict],
namespace: Optional[str] = None) -> KBQueryResult:
if self._index is None:
raise RuntimeError(self._connection_error_msg)

@@ -471,7 +475,7 @@ def _query_index(self,
result = self._index.query(vector=query.values,
sparse_vector=query.sparse_values,
top_k=top_k,
namespace=query.namespace,
namespace=namespace,
filter=metadata_filter,
include_metadata=True,
_check_return_type=_check_return_type,
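At the lowest level, the namespace is passed straight through to the Pinecone `Index.query()` call, replacing the per-query `Query.namespace` field that is removed below. A short sketch of querying the knowledge base directly, assuming `kb` is a connected `KnowledgeBase` as above and that `QueryResult.documents` exposes `id`/`text` as in the existing models:

```python
from canopy.models.data_models import Query

results = kb.query(
    [Query(text="vector databases")],
    global_metadata_filter={"website": "wiki"},  # optional, applied to every query
    namespace="customer-a",                      # new: select the Pinecone namespace
)
for result in results:
    for doc in result.documents:
        print(doc.id, doc.text[:80])
```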
4 changes: 0 additions & 4 deletions src/canopy/models/data_models.py
@@ -12,10 +12,6 @@

class Query(BaseModel):
text: str = Field(description="The query text.")
namespace: str = Field(
default="",
description="The namespace of the query. To learn more about namespaces, see https://docs.pinecone.io/docs/namespaces", # noqa: E501
)
metadata_filter: Optional[dict] = Field(
default=None,
description="A Pinecone metadata filter, to learn more about metadata filters, see https://docs.pinecone.io/docs/metadata-filtering", # noqa: E501
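Removing `Query.namespace` means per-query namespaces are no longer supported; the namespace is now supplied once, at call time. Roughly, the migration looks like this (both forms assume `kb` as above):

```python
from canopy.models.data_models import Query

# Before this PR: namespace was a field on the Query model.
# kb.query([Query(text="pricing", namespace="customer-a")])

# After this PR: namespace is an argument to the query call itself.
results = kb.query([Query(text="pricing")], namespace="customer-a")
```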
16 changes: 9 additions & 7 deletions src/canopy_server/app.py
@@ -25,7 +25,7 @@
APIRouter
)
import uvicorn
from typing import cast, Union
from typing import cast, Union, Optional

from canopy.models.api_models import (
StreamingChatResponse,
@@ -95,13 +95,19 @@
logger: logging.Logger


@openai_api_router.post(
"/{namespace}/chat/completions",
response_model=None,
responses={500: {"description": "Failed to chat with Canopy"}}, # noqa: E501
)
@openai_api_router.post(
"/chat/completions",
response_model=None,
responses={500: {"description": "Failed to chat with Canopy"}}, # noqa: E501
)
async def chat(
request: ChatRequest = Body(...),
namespace: Optional[str] = None,
) -> APIChatResponse:
"""
Chat with Canopy, using the LLM and context engine, and return a response.
@@ -111,6 +117,7 @@ async def chat(

""" # noqa: E501
try:
logger.debug(f"The namespace is {namespace}")
session_id = request.user or "None" # noqa: F841
question_id = str(uuid.uuid4())
logger.debug(f"Received chat request: {request.messages[-1].content}")
@@ -288,13 +295,8 @@ async def startup():


def _init_routes(app):
# Include the application level router (health, shutdown, ...)
app.include_router(application_router, include_in_schema=False)
app.include_router(application_router, prefix=f"/{API_VERSION}")
# Include the API without version == latest
app.include_router(context_api_router, include_in_schema=False)
app.include_router(openai_api_router, include_in_schema=False)
# Include the API version in the path, API_VERSION should be the latest version.
app.include_router(application_router, prefix=f"/{API_VERSION}")
app.include_router(context_api_router, prefix=f"/{API_VERSION}", tags=["Context"])
app.include_router(openai_api_router, prefix=f"/{API_VERSION}", tags=["LLM"])

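On the server, the chat route is now registered twice, with and without a `{namespace}` path parameter, so a namespace can be chosen per request via the URL. A rough sketch of calling it, assuming the default Canopy server on `localhost:8000`, that `API_VERSION` resolves to `v1`, and the OpenAI-style request body the server already accepts:

```python
import requests

payload = {"messages": [{"role": "user", "content": "What is a namespace?"}]}

# Default namespace: POST /v1/chat/completions
requests.post("http://localhost:8000/v1/chat/completions", json=payload)

# Scoped to a namespace: POST /v1/{namespace}/chat/completions
resp = requests.post(
    "http://localhost:8000/v1/customer-a/chat/completions",
    json=payload,
)
print(resp.json())
```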
9 changes: 6 additions & 3 deletions tests/unit/chat_engine/test_chat_engine.py
@@ -118,7 +118,8 @@ def test_chat(self, history_length=5, snippet_length=10):
)
self.mock_context_engine.query.assert_called_once_with(
expected['queries'],
max_context_tokens=70
max_context_tokens=70,
namespace=None
)
self.mock_llm.chat_completion.assert_called_once_with(
system_prompt=expected['prompt'],
@@ -168,7 +169,8 @@ def test_chat_engine_params(self,
)
self.mock_context_engine.query.assert_called_once_with(
expected['queries'],
max_context_tokens=max_context_tokens
max_context_tokens=max_context_tokens,
namespace=None
)
self.mock_llm.chat_completion.assert_called_once_with(
system_prompt=expected['prompt'],
@@ -199,7 +201,8 @@ def test_get_context(self):
)
self.mock_context_engine.query.assert_called_once_with(
expected['queries'],
max_context_tokens=70
max_context_tokens=70,
namespace=None
)

assert isinstance(context, Context)
4 changes: 2 additions & 2 deletions tests/unit/context_engine/test_context_engine.py
@@ -81,7 +81,7 @@ def test_query(context_engine,

assert result == mock_context
mock_knowledge_base.query.assert_called_once_with(
queries, global_metadata_filter=None)
queries, global_metadata_filter=None, namespace=None)
mock_context_builder.build.assert_called_once_with(
mock_query_result, max_context_tokens)

@@ -112,7 +112,7 @@ def test_query_with_metadata_filter(context_engine,

assert result == mock_context
mock_knowledge_base.query.assert_called_once_with(
queries, global_metadata_filter=mock_global_metadata_filter)
queries, global_metadata_filter=mock_global_metadata_filter, namespace=None)
mock_context_builder.build.assert_called_once_with(
mock_query_result, max_context_tokens)
