From 00f0cc0795e3ea72d8e099c14410910459fc3306 Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Wed, 1 May 2024 10:52:34 -0400
Subject: [PATCH 1/5] Extract _make_rag_request()

---
 edgedb/ai/core.py | 49 +++++++++++++++++++++++++------------------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index e7fd0700..00145430 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -85,6 +85,23 @@ def with_context(self, **kwargs) -> typing.Self:
         rv.client = self.client
         return rv
 
+    def _make_rag_request(
+        self,
+        *,
+        message: str,
+        context: typing.Optional[types.QueryContext] = None,
+        stream: bool,
+    ) -> types.RAGRequest:
+        if context is None:
+            context = self.context
+        return types.RAGRequest(
+            model=self.options.model,
+            prompt=self.options.prompt,
+            context=context,
+            query=message,
+            stream=stream,
+        )
+
 
 class EdgeDBAI(BaseEdgeDBAI):
     client: httpx.Client
@@ -95,14 +112,10 @@ def _init_client(self, **kwargs):
     def query_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ) -> str:
-        if context is None:
-            context = self.context
         resp = self.client.post(
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=False,
             ).to_httpx_request()
         )
@@ -112,16 +125,12 @@ def stream_rag(
     def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ):
-        if context is None:
-            context = self.context
         with httpx_sse.connect_sse(
             self.client,
             "post",
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=True,
             ).to_httpx_request(),
         ) as event_source:
@@ -139,14 +148,10 @@ def _init_client(self, **kwargs):
     async def query_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ) -> str:
-        if context is None:
-            context = self.context
         resp = await self.client.post(
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=False,
             ).to_httpx_request()
         )
@@ -156,16 +161,12 @@ async def stream_rag(
     async def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
     ):
-        if context is None:
-            context = self.context
         async with httpx_sse.aconnect_sse(
             self.client,
             "post",
-            **types.RAGRequest(
-                model=self.options.model,
-                prompt=self.options.prompt,
+            **self._make_rag_request(
                 context=context,
-                query=message,
+                message=message,
                 stream=True,
             ).to_httpx_request(),
         ) as event_source:

From 6642a980ddc5c2b8096868ff46e211120a084b06 Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Wed, 1 May 2024 12:49:30 -0400
Subject: [PATCH 2/5] Expose embeddings API

---
 edgedb/ai/core.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index 00145430..6c093540 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -138,6 +138,13 @@ def stream_rag(
             for sse in event_source.iter_sse():
                 yield sse.data
 
+    def generate_custom_embeddings(self, *inputs: str, model: str):
+        resp = self.client.post(
+            "/embeddings", json={"input": inputs, "model": model}
+        )
+        resp.raise_for_status()
+        return resp.json()
+
 
 class AsyncEdgeDBAI(BaseEdgeDBAI):
     client: httpx.AsyncClient
@@ -173,3 +180,10 @@ async def stream_rag(
             event_source.response.raise_for_status()
             async for sse in event_source.aiter_sse():
                 yield sse.data
+
+    async def generate_custom_embeddings(self, *inputs: str, model: str):
+        resp = await self.client.post(
+            "/embeddings", json={"input": inputs, "model": model}
+        )
+        resp.raise_for_status()
+        return resp.json()

From 8d0d913864879531c4c7e72fb1b90c91562ccdcb Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Wed, 1 May 2024 14:42:22 -0400
Subject: [PATCH 3/5] CRF: rename to `generate_embeddings()`

---
 edgedb/ai/core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index 6c093540..ccb19c2d 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -138,7 +138,7 @@ def stream_rag(
             for sse in event_source.iter_sse():
                 yield sse.data
 
-    def generate_custom_embeddings(self, *inputs: str, model: str):
+    def generate_embeddings(self, *inputs: str, model: str):
         resp = self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )
@@ -181,7 +181,7 @@ async def stream_rag(
             async for sse in event_source.aiter_sse():
                 yield sse.data
 
-    async def generate_custom_embeddings(self, *inputs: str, model: str):
+    async def generate_embeddings(self, *inputs: str, model: str):
         resp = await self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )

From aa3c755423e52cc984c46636e0d011fdf4794d23 Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Wed, 1 May 2024 15:07:34 -0400
Subject: [PATCH 4/5] CRF: unpack embedding result

---
 edgedb/ai/core.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index ccb19c2d..b9bcd76a 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -124,7 +124,7 @@ def query_rag(
 
     def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
-    ):
+    ) -> typing.Iterator[str]:
         with httpx_sse.connect_sse(
             self.client,
             "post",
@@ -138,12 +138,14 @@ def stream_rag(
             for sse in event_source.iter_sse():
                 yield sse.data
 
-    def generate_embeddings(self, *inputs: str, model: str):
+    def generate_embeddings(
+        self, *inputs: str, model: str
+    ) -> list[list[float]]:
         resp = self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )
         resp.raise_for_status()
-        return resp.json()
+        return [data["embedding"] for data in resp.json()["data"]]
 
 
 class AsyncEdgeDBAI(BaseEdgeDBAI):
@@ -167,7 +169,7 @@ async def query_rag(
 
     async def stream_rag(
         self, message: str, context: typing.Optional[types.QueryContext] = None
-    ):
+    ) -> typing.AsyncIterator[str]:
         async with httpx_sse.aconnect_sse(
             self.client,
             "post",
@@ -181,9 +183,11 @@ async def stream_rag(
             async for sse in event_source.aiter_sse():
                 yield sse.data
 
-    async def generate_embeddings(self, *inputs: str, model: str):
+    async def generate_embeddings(
+        self, *inputs: str, model: str
+    ) -> list[list[float]]:
         resp = await self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )
         resp.raise_for_status()
-        return resp.json()
+        return [data["embedding"] for data in resp.json()["data"]]

From cc6c9a919eda900d3a9eb3c1ac39ba86b2291ff1 Mon Sep 17 00:00:00 2001
From: Fantix King
Date: Tue, 28 May 2024 10:19:48 -0400
Subject: [PATCH 5/5] Unbox embedding data

---
 edgedb/ai/core.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/edgedb/ai/core.py b/edgedb/ai/core.py
index b9bcd76a..69fe235d 100644
--- a/edgedb/ai/core.py
+++ b/edgedb/ai/core.py
@@ -138,14 +138,12 @@ def stream_rag(
             for sse in event_source.iter_sse():
                 yield sse.data
 
-    def generate_embeddings(
-        self, *inputs: str, model: str
-    ) -> list[list[float]]:
+    def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
         resp = self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )
         resp.raise_for_status()
-        return [data["embedding"] for data in resp.json()["data"]]
+        return resp.json()["data"][0]["embedding"]
 
 
 class AsyncEdgeDBAI(BaseEdgeDBAI):
@@ -185,9 +183,9 @@ async def stream_rag(
 
     async def generate_embeddings(
         self, *inputs: str, model: str
-    ) -> list[list[float]]:
+    ) -> list[float]:
         resp = await self.client.post(
             "/embeddings", json={"input": inputs, "model": model}
         )
         resp.raise_for_status()
-        return [data["embedding"] for data in resp.json()["data"]]
+        return resp.json()["data"][0]["embedding"]
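
Usage sketch (not part of the patch series): the snippet below shows how the
embeddings API added above might be called. It assumes an EdgeDBAI instance
`ai` has already been constructed elsewhere, which these patches do not show,
and the model name is only an illustrative placeholder.

    # Hypothetical example; `ai` is assumed to be an EdgeDBAI instance obtained
    # elsewhere, and the model name is a placeholder, not defined by this series.
    vector = ai.generate_embeddings(
        "Which breed of cat sheds the least?",
        model="text-embedding-3-small",
    )
    # After PATCH 5/5 ("Unbox embedding data"), the call returns a single
    # embedding vector (list[float]) rather than a list of vectors.
    print(len(vector))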