Skip to content

Commit

Permalink
[fix]: embedding fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
keenborder786 committed Dec 19, 2024
1 parent c823cc5 commit b854425
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions libs/community/langchain_community/embeddings/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, List, Optional, final

from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field, model_validator
Expand Down Expand Up @@ -116,7 +116,14 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
List of embeddings, one for each text.
"""
embeddings = self.client.create_embedding(texts)
return [list(map(float, e["embedding"])) for e in embeddings["data"]]
if not isinstance(embeddings["data"][0]["embedding"][0], list):
return [list(map(float, e["embedding"])) for e in embeddings["data"]]
else:
final_embeddings = []
for e in embeddings["data"]:
for data in e["embedding"]:
final_embeddings.append(list(map(float, data)))
return final_embeddings

def embed_query(self, text: str) -> List[float]:
"""Embed a query using the Llama model.
Expand All @@ -128,4 +135,7 @@ def embed_query(self, text: str) -> List[float]:
Embeddings for the text.
"""
embedding = self.client.embed(text)
return list(map(float, embedding))
if not isinstance(embedding, list):
return list(map(float, embedding))
else:
return list(map(float, embedding[0]))

0 comments on commit b854425

Please sign in to comment.