Skip to content

Commit

Permalink
format block
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Dec 18, 2023
1 parent 6566ff5 commit ae10c18
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
31 changes: 18 additions & 13 deletions libs/community/langchain_community/embeddings/gradient_ai.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from packaging.version import parse

__all__ = ["GradientEmbeddings"]

Expand Down Expand Up @@ -48,7 +42,7 @@ class GradientEmbeddings(BaseModel, Embeddings):

gradient_api_url: Optional[str] = None
"""Endpoint URL to use."""

query_for_retrieval: Optional[str] = None
"""Endpoint URL to use."""

Expand Down Expand Up @@ -78,7 +72,14 @@ def validate_environment(cls, values: Dict) -> Dict:
try:
import gradientai
except ImportError:
raise ImportError("GradientEmbeddings requires `pip install gradientai`.")
raise ImportError(
"GradientEmbeddings requires `pip install gradientai>=1.4.0`."
)

if parse(gradientai.__version__) < parse("1.4.0"):
raise ImportError(
"GradientEmbeddings requires `pip install gradientai>=1.4.0`."
)

gradient = gradientai.Gradient(
access_token=values["gradient_access_token"],
Expand Down Expand Up @@ -128,7 +129,9 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embeddings for the text.
"""
query = f"{self.query_for_retrieval} {text}" if self.query_for_retrieval else text
query = (
f"{self.query_for_retrieval} {text}" if self.query_for_retrieval else text
)
return self.embed_documents([query])[0]

async def aembed_query(self, text: str) -> List[float]:
Expand All @@ -140,6 +143,8 @@ async def aembed_query(self, text: str) -> List[float]:
Returns:
Embeddings for the text.
"""
query = f"{self.query_for_retrieval} {text}" if self.query_for_retrieval else text
query = (
f"{self.query_for_retrieval} {text}" if self.query_for_retrieval else text
)
embeddings = await self.aembed_documents([query])
return embeddings[0]
16 changes: 9 additions & 7 deletions libs/community/tests/unit_tests/embeddings/test_gradient_ai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List, Any, Dict
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

import pytest
from unittest.mock import MagicMock, patch
import sys

from langchain_community.embeddings import GradientEmbeddings

_MODEL_ID = "my_model_valid_id"
Expand All @@ -18,14 +19,13 @@
]



class GradientEmbeddingsModel(MagicMock):
"""MockGradientModel."""

def embed(self, inputs: List[Dict[str,str]]) -> Any:
def embed(self, inputs: List[Dict[str, str]]) -> Any:
"""Just duplicate the query m times."""
output = MagicMock()

embeddings = []
for i, inp in enumerate(inputs):
# verify correct ordering
Expand Down Expand Up @@ -70,6 +70,7 @@ class MockGradientaiPackage(MagicMock):
"""Mock Gradientai package."""

Gradient = MockGradient
__version__ = "1.4.0"


def test_gradient_llm_sync() -> None:
Expand Down Expand Up @@ -97,7 +98,6 @@ def test_gradient_llm_sync() -> None:
assert response == want



def test_gradient_wrong_setup() -> None:
with pytest.raises(Exception):
GradientEmbeddings(
Expand All @@ -107,6 +107,7 @@ def test_gradient_wrong_setup() -> None:
model=_MODEL_ID,
)


def test_gradient_wrong_setup2() -> None:
with pytest.raises(Exception):
GradientEmbeddings(
Expand All @@ -116,6 +117,7 @@ def test_gradient_wrong_setup2() -> None:
model=_MODEL_ID,
)


def test_gradient_wrong_setup3() -> None:
with pytest.raises(Exception):
GradientEmbeddings(
Expand Down

0 comments on commit ae10c18

Please sign in to comment.