From b828efa10944f24cb451ea1ed1791c13e1c52245 Mon Sep 17 00:00:00 2001
From: mohammadrezapourreza
Date: Tue, 31 Oct 2023 09:33:02 -0400
Subject: [PATCH] DH-4917/fix the embedding issue

---
 dataherald/sql_generator/dataherald_sqlagent.py | 8 +++++---
 requirements.txt                                | 1 +
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index 7d6a6cac..75ef484a 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -22,6 +22,7 @@
     CallbackManagerForToolRun,
 )
 from langchain.chains.llm import LLMChain
+from langchain.embeddings import OpenAIEmbeddings
 from langchain.schema import AgentAction
 from langchain.tools.base import BaseTool
 from overrides import override
@@ -205,9 +206,10 @@ def get_embedding(
         self, text: str, model: str = "text-embedding-ada-002"
     ) -> List[float]:
         text = text.replace("\n", " ")
-        return openai.Embedding.create(input=[text], model=model)["data"][0][
-            "embedding"
-        ]
+        embedding = OpenAIEmbeddings(
+            openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model
+        )
+        return embedding.embed_query(text)
 
     def cosine_similarity(self, a: List[float], b: List[float]) -> float:
         return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4)
diff --git a/requirements.txt b/requirements.txt
index efc79b52..f9561df6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,3 +34,4 @@ sphinx-book-theme==1.0.1
 boto3==1.28.38
 botocore==1.31.38
 PyAthena==3.0.6
+tiktoken==0.5.1