diff --git a/tests/integration_tests/test_embedding.py b/tests/integration_tests/test_embedding.py index d68c3fd..60efcbd 100644 --- a/tests/integration_tests/test_embedding.py +++ b/tests/integration_tests/test_embedding.py @@ -24,3 +24,27 @@ def test_embeddings(logging_conf): print(err) except zhipuai.core._errors.APIStatusError as err: print(err) + + +def test_embeddings_dimensions(logging_conf): + logging.config.dictConfig(logging_conf) # type: ignore + + client = ZhipuAI() + try: + response = client.embeddings.create( + model="embedding-3", #填写需要调用的模型名称 + input="你好", + dimensions=512, + extra_body={"model_version": "v1"} + ) + assert response.data[0].object == "embedding" + assert len(response.data[0].embedding) == 512 + print(len(response.data[0].embedding)) + + + except zhipuai.core._errors.APIRequestFailedError as err: + print(err) + except zhipuai.core._errors.APIInternalError as err: + print(err) + except zhipuai.core._errors.APIStatusError as err: + print(err) diff --git a/zhipuai/api_resource/embeddings.py b/zhipuai/api_resource/embeddings.py index 4f5fb15..09d06c4 100644 --- a/zhipuai/api_resource/embeddings.py +++ b/zhipuai/api_resource/embeddings.py @@ -22,6 +22,7 @@ def create( *, input: Union[str, List[str], List[int], List[List[int]]], model: Union[str], + dimensions: Union[int], encoding_format: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, request_id: Optional[str] | NotGiven = NOT_GIVEN, @@ -39,6 +40,7 @@ def create( body={ "input": input, "model": model, + "dimensions": dimensions, "encoding_format": encoding_format, "user": user, "request_id": request_id,