-
Notifications
You must be signed in to change notification settings - Fork 16.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
standard-tests: add tools standard tests
- Loading branch information
Showing
2 changed files
with
113 additions
and
0 deletions.
There are no files selected for viewing
43 changes: 43 additions & 0 deletions
43
libs/standard-tests/langchain_standard_tests/integration_tests/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from langchain_core.tools import BaseTool | ||
|
||
from langchain_standard_tests.unit_tests.tools import ToolsTests | ||
|
||
|
||
class ToolsIntegrationTests(ToolsTests): | ||
def test_invoke_matches_output_schema(self, tool: BaseTool) -> None: | ||
result = tool.invoke(self.tool_invoke_params_example) | ||
|
||
if tool.response_format == "content": | ||
content = result | ||
elif tool.response_format == "content_and_artifact": | ||
# should be (content, artifact) | ||
assert isinstance(result, tuple) | ||
assert len(result) == 2 | ||
content, artifact = result | ||
|
||
assert artifact # artifact can be anything, but shouldn't be none | ||
|
||
# check content is a valid ToolMessage content | ||
assert isinstance(content, (str, list)) | ||
if isinstance(content, list): | ||
# content blocks must be str or dict | ||
assert all(isinstance(c, (str, dict)) for c in content) | ||
|
||
async def test_async_invoke_matches_output_schema(self, tool: BaseTool) -> None: | ||
result = await tool.ainvoke(self.tool_invoke_params_example) | ||
|
||
if tool.response_format == "content": | ||
content = result | ||
elif tool.response_format == "content_and_artifact": | ||
# should be (content, artifact) | ||
assert isinstance(result, tuple) | ||
assert len(result) == 2 | ||
content, artifact = result | ||
|
||
assert artifact # artifact can be anything, but shouldn't be none | ||
|
||
# check content is a valid ToolMessage content | ||
assert isinstance(content, (str, list)) | ||
if isinstance(content, list): | ||
# content blocks must be str or dict | ||
assert all(isinstance(c, (str, dict)) for c in content) |
70 changes: 70 additions & 0 deletions
70
libs/standard-tests/langchain_standard_tests/unit_tests/tools.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
from abc import abstractmethod | ||
from typing import Callable, Tuple, Type, Union | ||
from unittest import mock | ||
|
||
import pytest | ||
from langchain_core.tools import BaseTool | ||
from pydantic import SecretStr | ||
|
||
from langchain_standard_tests.base import BaseStandardTests | ||
|
||
|
||
class ToolsTests(BaseStandardTests): | ||
@property | ||
@abstractmethod | ||
def tool_constructor(self) -> Union[Type[BaseTool], Callable]: ... | ||
|
||
@property | ||
def tool_constructor_params(self) -> dict: | ||
return {} | ||
|
||
@property | ||
def tool_invoke_params_example(self) -> dict: | ||
return {} | ||
|
||
@pytest.fixture | ||
def tool(self) -> BaseTool: | ||
return self.tool_constructor(**self.tool_constructor_params) | ||
|
||
|
||
class ToolsUnitTests(ToolsTests): | ||
def test_init(self) -> None: | ||
tool = self.tool_constructor(**self.tool_constructor_params) | ||
assert tool is not None | ||
|
||
@property | ||
def init_from_env_params(self) -> Tuple[dict, dict, dict]: | ||
"""Return env vars, init args, and expected instance attrs for initializing | ||
from env vars.""" | ||
return {}, {}, {} | ||
|
||
def test_init_from_env(self) -> None: | ||
env_params, tools_params, expected_attrs = self.init_from_env_params | ||
if env_params: | ||
with mock.patch.dict(os.environ, env_params): | ||
tool = self.tool_constructor(**tools_params) | ||
assert tool is not None | ||
for k, expected in expected_attrs.items(): | ||
actual = getattr(tool, k) | ||
if isinstance(actual, SecretStr): | ||
actual = actual.get_secret_value() | ||
assert actual == expected | ||
|
||
def test_has_name(self, tool: BaseTool) -> None: | ||
assert tool.name | ||
|
||
def test_has_description(self, tool: BaseTool) -> None: | ||
assert tool.description | ||
|
||
def test_has_input_schema(self, tool: BaseTool) -> None: | ||
assert tool.get_input_schema() | ||
|
||
def test_input_schema_matches_invoke_params(self, tool: BaseTool) -> None: | ||
""" | ||
Tests that the provided example params match the declared input schema | ||
""" | ||
# this will be a pydantic object | ||
input_schema = tool.get_input_schema() | ||
|
||
assert input_schema(**self.tool_invoke_params_example) |