Skip to content

Commit

Permalink
improve llamacpp embeddings (#12972)
Browse files Browse the repository at this point in the history
- **Description:**
Improve the llamacpp embedding class by adding the `device` parameter so it
can be passed to the model and used with `gpu`, `cpu`, or Apple Metal
(`mps`).
Improve performance by making use of the bulk client API to compute
embeddings in batches.
  
  - **Dependencies:** none
  - **Tag maintainer:** 
@hwchase17

---------

Co-authored-by: Harrison Chase <[email protected]>
Co-authored-by: Chester Curme <[email protected]>
  • Loading branch information
3 people authored Aug 31, 2024
1 parent f882824 commit 654da27
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions libs/community/langchain_community/embeddings/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
verbose: bool = Field(True, alias="verbose")
"""Print verbose output to stderr."""

device: Optional[str] = Field(None, alias="device")
"""Device type to use and pass to the model"""

class Config:
    # Pydantic model configuration: reject any keyword arguments that are
    # not declared fields, so typos in model params fail loudly.
    extra = "forbid"

Expand All @@ -75,6 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict:
"n_threads",
"n_batch",
"verbose",
"device",
]
model_params = {k: values[k] for k in model_param_names}
# For backwards compatibility, only include if non-null.
Expand Down Expand Up @@ -108,8 +112,8 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
Returns:
List of embeddings, one for each text.
"""
embeddings = [self.client.embed(text) for text in texts]
return [list(map(float, e)) for e in embeddings]
embeddings = self.client.create_embedding(texts)
return [list(map(float, e["embedding"])) for e in embeddings["data"]]

def embed_query(self, text: str) -> List[float]:
"""Embed a query using the Llama model.
Expand Down

0 comments on commit 654da27

Please sign in to comment.