diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py index 1623ccf631a2f..49091b5aa6bfb 100644 --- a/libs/community/langchain_community/embeddings/llamacpp.py +++ b/libs/community/langchain_community/embeddings/llamacpp.py @@ -57,6 +57,9 @@ class LlamaCppEmbeddings(BaseModel, Embeddings): verbose: bool = Field(True, alias="verbose") """Print verbose output to stderr.""" + device: Optional[str] = Field(None, alias="device") + """Device type to use and pass to the model""" + class Config: extra = "forbid" @@ -75,6 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict: "n_threads", "n_batch", "verbose", + "device", ] model_params = {k: values[k] for k in model_param_names} # For backwards compatibility, only include if non-null. @@ -108,8 +112,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: Returns: List of embeddings, one for each text. """ - embeddings = [self.client.embed(text) for text in texts] - return [list(map(float, e)) for e in embeddings] + embeddings = self.client.create_embedding(texts) + return [list(map(float, e["embedding"])) for e in embeddings["data"]] def embed_query(self, text: str) -> List[float]: """Embed a query using the Llama model.