diff --git a/README.md b/README.md index b2df79e5..ab7c7617 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,11 @@ The latest version of the package contains the following experiments: | ------------------------ | ----------------------- | ------------------- | | [`EvaluationHarness`][1] | Evaluation orchestrator | September 2024 | | [`OpenAIFunctionCaller`][2] | Function Calling Component | September 2024 | +| [`OpenAPITool`][3] | OpenAPITool component | September 2024 | [1]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/evaluation/harness [2]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openai +[3]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openapi ## Usage diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py new file mode 100644 index 00000000..6867e62b --- /dev/null +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool +from haystack_experimental.components.tools.openapi.types import LLMProvider + +__all__ = ["LLMProvider", "OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py new file mode 100644 index 00000000..2b13f70e --- /dev/null +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +import requests + +from haystack_experimental.components.tools.openapi._payload_extraction import ( + 
def send_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Send an HTTP request and return the decoded JSON response.

    :param request: The request description. `url` and `method` are required; `headers`, `params`, `json`,
        `auth` and `cookies` are optional and forwarded to `requests.request` when present.
    :returns: The JSON body of the response, as produced by `response.json()`.
    :raises HttpClientError: If the server answers with an error status, the request cannot be sent, or
        any other error occurs while sending the request or reading the response.
    """
    url = request["url"]
    # Work on a copy so the caller's headers mapping is not handed to the HTTP library directly.
    headers = {**request.get("headers", {})}
    try:
        response = requests.request(
            request["method"],
            url,
            headers=headers,
            params=request.get("params", {}),
            json=request.get("json"),
            auth=request.get("auth"),
            # The apiKey auth strategy stores `in: cookie` keys under "cookies"; without forwarding
            # them here, cookie-authenticated services would always receive anonymous requests.
            cookies=request.get("cookies"),
            timeout=30,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        logger.warning("HTTP error occurred: %s while sending request to %s", e, url)
        raise HttpClientError(f"HTTP error occurred: {e}") from e
    except requests.exceptions.RequestException as e:
        logger.warning("Request error occurred: %s while sending request to %s", e, url)
        raise HttpClientError(f"HTTP error occurred: {e}") from e
    except Exception as e:
        # Anything else (for example a body that is not valid JSON) is reported through the same error type.
        logger.warning("An error occurred: %s while sending request to %s", e, url)
        raise HttpClientError(f"An error occurred: {e}") from e
def create_http_auth_function(token: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
    """
    Create a function that applies the http authentication strategy to a given request.

    :param token: the authentication token to use.
    :returns: a function that writes the token into the `Authorization` header of a request
        when given an HTTP bearer security scheme.
    """

    def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
        """
        Apply the HTTP authentication strategy to the given request.

        :param security_scheme: the security scheme from the OpenAPI spec.
        :param request: the request to apply the authentication to; it is modified in place.
        :raises ValueError: if the scheme is not an HTTP scheme, is not a bearer scheme, or the token is empty.
        """
        if security_scheme["type"] != "http":
            raise ValueError(
                "HTTPAuthentication strategy received a non-HTTP security scheme."
            )

        # only bearer auth is supported, basic auth is not handled yet
        http_scheme = security_scheme["scheme"]
        if http_scheme.lower() != "bearer":
            raise ValueError(
                f"Unsupported HTTP authentication scheme: {http_scheme}"
            )

        if not token:
            raise ValueError("Token must be provided for Bearer Auth.")

        request_headers = request.setdefault("headers", {})
        request_headers["Authorization"] = f"Bearer {token}"

    return apply_auth
+ """ + security_schemes = self.openapi_spec.get_security_schemes() + if not self.credentials: + return lambda security_scheme, request: None # No-op function + if isinstance(self.credentials, str): + return self._create_authentication_from_string( + self.credentials, security_schemes + ) + raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") + + def get_tools_definitions(self) -> List[Dict[str, Any]]: + """ + Get the tools definitions used as tools LLM parameter. + + :returns: The tools definitions passed to the LLM as tools parameter. + """ + provider_to_converter = defaultdict( + lambda: openai_converter, + { + LLMProvider.ANTHROPIC: anthropic_converter, + LLMProvider.COHERE: cohere_converter, + } + ) + converter = provider_to_converter[self.llm_provider] + return converter(self.openapi_spec) + + def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: + """ + Get the payload extractor for the LLM provider. + + This function knows how to extract the exact function payload from the LLM generated function calling payload. + :returns: The payload extractor function. + """ + provider_to_arguments_field_name = defaultdict( + lambda: "arguments", + { + LLMProvider.ANTHROPIC: "input", + LLMProvider.COHERE: "parameters", + } + ) + arguments_field_name = provider_to_arguments_field_name[self.llm_provider] + return create_function_payload_extractor(arguments_field_name) + + def _create_authentication_from_string( + self, credentials: str, security_schemes: Dict[str, Any] + ) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: + for scheme in security_schemes.values(): + if scheme["type"] == "apiKey": + return create_api_key_auth_function(api_key=credentials) + if scheme["type"] == "http": + return create_http_auth_function(token=credentials) + raise ValueError( + f"Unsupported authentication type '{scheme['type']}' provided." 
def _resolve_parameter(parameter: Dict[str, Any], location: str, arguments: Dict[str, Any]) -> Any:
    """Return the supplied value for `parameter`, or None if optional and absent; raise if required and absent."""
    name = parameter["name"]
    value = arguments.get(name)
    # Only None counts as "not provided": 0, False and "" are legitimate argument values.
    if value is None and parameter.get("required", False):
        raise ValueError(f"Missing required {location} parameter: {name}")
    return value


def build_request(operation: Operation, **kwargs) -> Dict[str, Any]:
    """
    Build an HTTP request for the operation.

    :param operation: The operation to build the request for.
    :param kwargs: The arguments to use for building the request. Each argument is matched by name against
        the path, header and query parameters of the operation; an argument that is absent or `None` is
        treated as not provided.
    :returns: The HTTP request as a dictionary with the keys `url`, `method`, `headers`, `params` and `json`.
    :raises ValueError: If a required parameter is missing.
    :raises NotImplementedError: If the request body content type is not supported. We only support JSON payloads.
    """
    # url: substitute every provided path parameter into its `{name}` placeholder
    path = operation.path
    for parameter in operation.get_parameters("path"):
        value = _resolve_parameter(parameter, "path", kwargs)
        if value is not None:
            path = path.replace(f"{{{parameter['name']}}}", str(value))
    url = operation.get_server() + path

    method = operation.method.lower()

    # headers: header values are always sent as strings
    headers = {}
    for parameter in operation.get_parameters("header"):
        value = _resolve_parameter(parameter, "header", kwargs)
        if value is not None:
            headers[parameter["name"]] = str(value)

    # query params: values keep their original type
    query_params = {}
    for parameter in operation.get_parameters("query"):
        value = _resolve_parameter(parameter, "query", kwargs)
        if value is not None:
            query_params[parameter["name"]] = value

    json_payload = None
    request_body = operation.request_body
    if request_body:
        if "application/json" not in request_body.get("content", {}):
            raise NotImplementedError("Request body content type not supported")
        # All arguments are sent as the JSON body, including those also used as parameters above.
        json_payload = dict(kwargs)

    return {
        "url": url,
        "method": method,
        "headers": headers,
        "params": query_params,
        "json": json_payload,
    }
+ """ + fn_invocation_payload = {} + try: + fn_extractor = self.client_config.get_payload_extractor() + fn_invocation_payload = fn_extractor(function_payload) + except Exception as e: + raise OpenAPIClientError( + f"Error extracting function invocation payload: {str(e)}" + ) from e + + if "name" not in fn_invocation_payload or "arguments" not in fn_invocation_payload: + raise OpenAPIClientError( + f"Function invocation payload does not contain 'name' or 'arguments' keys: {fn_invocation_payload}, " + f"the payload extraction function may be incorrect." + ) + # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on + operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload["name"]) + request = build_request(operation, **fn_invocation_payload["arguments"]) + apply_authentication(self.client_config.get_auth_function(), operation, request) + return self.client_config.request_sender(request) + + +class OpenAPIClientError(Exception): + """Exception raised for errors in the OpenAPI client.""" diff --git a/haystack_experimental/components/tools/openapi/_payload_extraction.py b/haystack_experimental/components/tools/openapi/_payload_extraction.py new file mode 100644 index 00000000..6247c56a --- /dev/null +++ b/haystack_experimental/components/tools/openapi/_payload_extraction.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import json +from typing import Any, Callable, Dict, List, Optional, Union + + +def create_function_payload_extractor( + arguments_field_name: str, +) -> Callable[[Any], Dict[str, Any]]: + """ + Extracts invocation payload from a given LLM completion containing function invocation. + + :param arguments_field_name: The name of the field containing the function arguments. + :return: A function that extracts the function invocation details from the LLM payload. 
def _get_dict_converter(
    obj: Any, method_names: Optional[List[str]] = None
) -> Union[Callable[[], Dict[str, Any]], None]:
    """Return the first callable attribute of `obj` among `method_names`, or None when there is none."""
    # by default look for the pydantic v2 method first, then the v1 one
    candidates = method_names or ["model_dump", "dict"]
    for candidate in candidates:
        member = getattr(obj, candidate, None)
        if callable(member):
            return member
    return None


def _is_primitive(obj) -> bool:
    """Tell whether `obj` is None, a bool, a number or a string."""
    return obj is None or isinstance(obj, (bool, int, float, str))


def _required_fields(arguments_field_name: str) -> List[str]:
    """List the keys a dictionary must contain to count as a function invocation."""
    return ["name", arguments_field_name]


def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]:
    """
    Depth-first search for the first dictionary holding both `name` and the arguments field.

    :param payload: The object to search. Objects exposing a dict converter and dataclass instances are
        converted to dictionaries first; dictionaries and lists are searched recursively.
    :param arguments_field_name: The name of the field containing the function arguments.
    :returns: The matching dictionary, or an empty dictionary when nothing matches.
    """
    if _is_primitive(payload):
        return {}

    # normalize rich objects into plain dictionaries before inspecting them
    to_dict = _get_dict_converter(payload)
    if to_dict:
        payload = to_dict()
    elif dataclasses.is_dataclass(payload):
        payload = dataclasses.asdict(payload)

    if isinstance(payload, dict):
        wanted = _required_fields(arguments_field_name)
        if all(key in payload for key in wanted):
            # this is the payload we are looking for
            return payload
        children = payload.values()
    elif isinstance(payload, list):
        children = payload
    else:
        return {}

    for child in children:
        found = _search(child, arguments_field_name)
        if found:
            return found
    return {}
# Fixed fields of an OpenAPI Path Item Object that hold an Operation Object.
_HTTP_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})


def _openapi_to_functions(
    service_openapi_spec: Dict[str, Any],
    parameters_name: str,
    parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Extracts functions from the OpenAPI specification, converts them into a function schema.

    :param service_openapi_spec: The OpenAPI specification to extract functions from.
    :param parameters_name: The name of the parameters field in the function schema.
    :param parse_endpoint_fn: The function to parse the endpoint specification. It is called once per
        operation; falsy results are skipped.
    :returns: A list of dictionaries, each dictionary representing a function schema.
    :raises ValueError: If the specification has no usable `openapi` version or the major version is lower
        than `MIN_REQUIRED_OPENAPI_SPEC_VERSION`.
    """

    # Doesn't enforce rigid spec validation because that would require a lot of dependencies
    # We check the version and require minimal fields to be present, so we can extract functions
    spec_version = service_openapi_spec.get("openapi")
    if not spec_version:
        raise ValueError(
            f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}"
        )
    try:
        # str() keeps this working when a loader produced a number (e.g. 3.0) instead of "3.0.x"
        service_openapi_spec_version = int(str(spec_version).split(".")[0])
    except ValueError as e:
        raise ValueError(
            f"Invalid OpenAPI spec provided. Could not parse version '{spec_version}'"
        ) from e
    # Compare the versions
    if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION:
        raise ValueError(
            f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
            f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
        )

    functions: List[Dict[str, Any]] = []
    for path_item in service_openapi_spec["paths"].values():
        for field_name, operation in path_item.items():
            # A path item also carries non-operation fields (summary, description, servers,
            # parameters, ...); handing those to the endpoint parser would yield garbage or crash.
            if str(field_name).lower() not in _HTTP_METHODS:
                continue
            function_dict = parse_endpoint_fn(operation, parameters_name)
            if function_dict:
                functions.append(function_dict)
    return functions
def _parse_property_attributes(
    property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Recursively reduce a property schema to its type, selected attributes and nested structure.

    :param property_schema: The property schema to parse.
    :param include_attributes: The attributes to copy into the parsed schema; when empty or omitted,
        `description`, `pattern` and `enum` are used.
    :returns: A dictionary containing the parsed property schema.
    """
    wanted = include_attributes or ["description", "pattern", "enum"]
    kind = property_schema.get("type")

    reduced: Dict[str, Any] = {}
    if kind:
        reduced["type"] = kind
    reduced.update({key: property_schema[key] for key in wanted if key in property_schema})

    if kind == "array":
        # arrays describe their elements through a single nested schema
        reduced["items"] = _parse_property_attributes(property_schema.get("items", {}), wanted)
    elif kind == "object":
        nested = property_schema.get("properties", {})
        reduced["properties"] = {
            name: _parse_property_attributes(sub_schema, wanted) for name, sub_schema in nested.items()
        }
        if "required" in property_schema:
            reduced["required"] = property_schema["required"]
    return reduced
+ """ + parameters = {} + for param in operation.get("parameters", []): + if "schema" in param: + parameters[param["name"]] = _parse_schema( + param["schema"], + param.get("required", False), + param.get("description", ""), + ) + if "requestBody" in operation: + content = ( + operation["requestBody"].get("content", {}).get("application/json", {}) + ) + if "schema" in content: + schema_properties = content["schema"].get("properties", {}) + required_properties = content["schema"].get("required", []) + for name, schema in schema_properties.items(): + parameters[name] = _parse_schema( + schema, name in required_properties, schema.get("description", "") + ) + return parameters + + +def _parse_schema( + schema: Dict[str, Any], required: bool, description: str +) -> Dict[str, Any]: # noqa: FBT001 + """ + Parses a schema part of an operation specification. + + :param schema: The schema to parse. + :param required: Whether the schema is required. + :param description: The description of the schema. + :returns: A dictionary containing the parsed schema. 
def _get_type(schema: Dict[str, Any]) -> str:
    """
    Translate the `type` of a schema into the type name used in the generated parameter definitions.

    :param schema: The schema to read the type from; a schema without `type` is treated as an object.
    :returns: The mapped type name.
    :raises ValueError: If the schema type has no mapping.
    """
    declared_type = schema.get("type", "object")
    mapped_type = {
        "integer": "int",
        "string": "str",
        "boolean": "bool",
        "number": "float",
        "object": "object",
        "array": "list",
    }.get(declared_type)
    if mapped_type is None:
        raise ValueError(f"Unsupported schema type {declared_type}")
    return mapped_type
from haystack_experimental.util import serialize_secrets_inplace

with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import:
    # pylint: disable=import-error
    from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

with LazyImport("Run 'pip install cohere-haystack'") as cohere_import:
    # pylint: disable=import-error
    from haystack_integrations.components.generators.cohere import CohereChatGenerator

logger = logging.getLogger(__name__)


@component
class OpenAPITool:
    """
    The OpenAPITool calls a RESTful endpoint of an OpenAPI service using payloads generated from human instructions.

    Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API:

    ```python
    from haystack.dataclasses import ChatMessage
    from haystack_experimental.components.tools.openapi import OpenAPITool, LLMProvider
    from haystack.utils import Secret

    tool = OpenAPITool(generator_api=LLMProvider.OPENAI,
                       generator_api_params={"model":"gpt-3.5-turbo"},
                       spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json",
                       credentials=Secret.from_token("<your-firecrawl-api-key>"))

    results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")])
    print(results)
    ```

    Similarly, you can use the OpenAPITool component to invoke **any** OpenAPI service/tool by providing the OpenAPI
    specification and credentials.
    """

    def __init__(
        self,
        generator_api: LLMProvider,
        generator_api_params: Optional[Dict[str, Any]] = None,
        spec: Optional[Union[str, Path]] = None,
        credentials: Optional[Secret] = None,
    ):
        """
        Initialize the OpenAPITool component.

        :param generator_api: The API provider for the chat generator.
        :param generator_api_params: Parameters to pass for the chat generator creation.
        :param spec: OpenAPI specification for the tool/service. This can be a URL or a local file path.
        :param credentials: Credentials for the tool/service.
        :raises ValueError: If `spec` is neither an existing file nor a valid HTTP(S) URL, or if
            `generator_api` is not a supported provider.
        """
        self.generator_api = generator_api
        self.generator_api_params = generator_api_params or {}  # store the generator API parameters for serialization
        self.chat_generator = self._init_generator(generator_api, generator_api_params or {})
        self.config_openapi: Optional[ClientConfiguration] = None
        self.open_api_service: Optional[OpenAPIServiceClient] = None
        self.spec = spec  # store the spec for serialization
        self.credentials = credentials  # store the credentials for serialization
        if spec:
            self.config_openapi = self._create_config(spec, credentials)
            self.open_api_service = OpenAPIServiceClient(self.config_openapi)

    def _create_config(self, spec: Union[str, Path], credentials: Optional[Secret]) -> ClientConfiguration:
        """
        Load the OpenAPI specification from `spec` and wrap it in a ClientConfiguration.

        Shared by `__init__` and `run` so that both resolve specifications and credentials the same way.

        :param spec: Local file path or HTTP(S) URL of the OpenAPI specification.
        :param credentials: Secret holding the service credentials, or None for unauthenticated services.
        :returns: The client configuration for the service described by `spec`.
        :raises ValueError: If `spec` is neither an existing file nor a valid HTTP(S) URL.
        """
        if os.path.isfile(spec):
            openapi_spec = OpenAPISpecification.from_file(spec)
        elif is_valid_http_url(str(spec)):
            openapi_spec = OpenAPISpecification.from_url(str(spec))
        else:
            raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL")
        return ClientConfiguration(
            openapi_spec=openapi_spec,
            credentials=credentials.resolve_value() if credentials else None,
            llm_provider=self.generator_api,
        )

    @component.output_types(service_response=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        fc_generator_kwargs: Optional[Dict[str, Any]] = None,
        spec: Optional[Union[str, Path]] = None,
        credentials: Optional[Secret] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator.

        :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). The last
            message should be human instruction containing enough information to generate the function calling payload
            suitable for the OpenAPI service/tool used. See the examples in the class docstring.
        :param fc_generator_kwargs: Additional arguments for the function calling payload generation process.
        :param spec: OpenAPI specification for the tool/service, overrides the one provided at initialization.
        :param credentials: Credentials for the tool/service, overrides the one provided at initialization.
        :returns: a dictionary containing the service response with the following key:
            - `service_response`: List of ChatMessages containing the service response. ChatMessages are generated
            based on the response from the OpenAPI service/tool and contains the JSON response from the service.
            If there is an error during the invocation, the response will be a ChatMessage with the error message under
            the `error` key.
        :raises ValueError: If `messages` is empty, the last message is not a non-empty user message, the runtime
            `spec` is invalid, or no OpenAPI specification is available at all.
        """
        if not messages:
            raise ValueError("No messages provided. At least one user message is required.")
        last_message = messages[-1]
        if not last_message.is_from(ChatRole.USER):
            raise ValueError(f"{last_message} not from the user")
        if not last_message.content:
            raise ValueError("Function calling instruction message content is empty.")

        # Build a new ClientConfiguration and OpenAPIServiceClient when a runtime spec and/or runtime credentials
        # are provided. Each runtime value overrides its init-time counterpart; the one that is not overridden
        # falls back to the value given at initialization instead of being silently dropped.
        openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service
        config_openapi: Optional[ClientConfiguration] = self.config_openapi
        effective_spec = spec or self.spec
        if (spec or credentials) and effective_spec:
            config_openapi = self._create_config(effective_spec, credentials or self.credentials)
            openapi_service = OpenAPIServiceClient(config_openapi)

        if not openapi_service or not config_openapi:
            raise ValueError(
                "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization "
                "or during runtime."
            )

        # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user
        fc_generator_kwargs = {
            "tools": config_openapi.get_tools_definitions(),
            **(fc_generator_kwargs or {}),
        }

        # generate function calling payload with the chat generator
        logger.debug(
            "Invoking chat generator with {message} to generate function calling payload.",
            message=last_message.content,
        )
        # Pass the kwargs by keyword: the second positional parameter of a chat generator's `run` is not
        # guaranteed to be `generation_kwargs` (e.g. it can be `streaming_callback`).
        fc_payload = self.chat_generator.run(messages, generation_kwargs=fc_generator_kwargs)
        try:
            invocation_payload = json.loads(fc_payload["replies"][0].content)
            logger.debug("Invoking tool with {payload}", payload=invocation_payload)
            service_response = openapi_service.invoke(invocation_payload)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Deliberate best-effort: errors are reported to the caller as a message instead of being raised.
            logger.error("Error invoking OpenAPI endpoint. Error: {e}", e=str(e))
            service_response = {"error": str(e)}
        response_messages = [ChatMessage.from_user(json.dumps(service_response))]

        return {"service_response": response_messages}

    @staticmethod
    def _copy_nested_dicts(params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of `params` in which every (nested) dict is a new object; other values are shared.

        This is sufficient to protect the original from `serialize_secrets_inplace`, which only rewrites dict entries.
        """
        return {
            key: OpenAPITool._copy_nested_dicts(value) if isinstance(value, dict) else value
            for key, value in params.items()
        }

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The component itself is left untouched: secrets are serialized on a copy of the generator parameters.

        :returns:
            The serialized component as a dictionary.
        """
        generator_api_params = self._copy_nested_dicts(self.generator_api_params)
        serialize_secrets_inplace(generator_api_params, keys=["api_key"], recursive=True)
        return default_to_dict(
            self,
            generator_api=self.generator_api.value,
            generator_api_params=generator_api_params,
            # a Path is not serializable; `__init__` accepts the string form as well
            spec=str(self.spec) if self.spec is not None else None,
            credentials=self.credentials.to_dict() if self.credentials else None,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAPITool":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, keys=["credentials"])
        generator_api_params = init_params.get("generator_api_params")
        if generator_api_params:
            deserialize_secrets_inplace(generator_api_params, keys=["api_key"])
        if "generator_api" in init_params:
            init_params["generator_api"] = LLMProvider.from_str(init_params["generator_api"])
        return default_from_dict(cls, data)

    def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]):
        """
        Initialize the chat generator based on the specified API provider and parameters.

        :param generator_api: The provider whose chat generator should be created.
        :param generator_api_params: Keyword arguments forwarded to the chat generator constructor.
        :returns: The chat generator instance for the given provider.
        :raises ValueError: If the provider is not supported.
        """
        if generator_api == LLMProvider.OPENAI:
            return OpenAIChatGenerator(**generator_api_params)
        if generator_api == LLMProvider.COHERE:
            cohere_import.check()
            return CohereChatGenerator(**generator_api_params)
        if generator_api == LLMProvider.ANTHROPIC:
            anthropic_import.check()
            return AnthropicChatGenerator(**generator_api_params)
        raise ValueError(f"Unsupported generator API: {generator_api}")
@dataclass
class Operation:
    """
    Represents an operation in an OpenAPI specification

    See https://spec.openapis.org/oas/latest.html#paths-object for details.
    Path objects can contain multiple operations, each with a unique combination of path and method.

    :param path: Path of the operation.
    :param method: HTTP method of the operation.
    :param operation_dict: Operation details from OpenAPI spec
    :param spec_dict: The encompassing OpenAPI specification.
    :param security_requirements: A list of security requirements for the operation.
    :param request_body: Request body details.
    :param parameters: Parameters for the operation.
    """

    path: str
    method: str
    operation_dict: Dict[str, Any]
    spec_dict: Dict[str, Any]
    security_requirements: List[Dict[str, List[str]]] = field(init=False)
    request_body: Dict[str, Any] = field(init=False)
    parameters: List[Dict[str, Any]] = field(init=False)

    def __post_init__(self):
        """
        Normalize the HTTP method and derive security requirements, request body and parameters.

        :raises ValueError: If `method` is not a valid HTTP method.
        """
        normalized_method = self.method.lower()
        if normalized_method not in VALID_HTTP_METHODS:
            raise ValueError(f"Invalid HTTP method: {self.method}")
        self.method = normalized_method
        # NOTE(review): an explicit empty operation-level `security: []` also falls back to the global
        # requirements here; the OpenAPI spec uses an empty list to disable security - confirm intent.
        self.security_requirements = self.operation_dict.get("security", []) or self.spec_dict.get("security", [])
        self.request_body = self.operation_dict.get("requestBody", {})

        # Path-level parameters apply to every operation of the path, but an operation-level parameter with
        # the same name and location overrides them, so inherited duplicates must not be added again.
        operation_params = list(self.operation_dict.get("parameters", []))
        path_params = self.spec_dict.get("paths", {}).get(self.path, {}).get("parameters", [])
        overridden = {
            (param.get("name"), param.get("in")) for param in operation_params if param.get("name") is not None
        }
        inherited = [
            param
            for param in path_params
            if param.get("name") is None or (param.get("name"), param.get("in")) not in overridden
        ]
        self.parameters = operation_params + inherited

    def get_parameters(
        self, location: Optional[Literal["header", "query", "path"]] = None
    ) -> List[Dict[str, Any]]:
        """
        Get the parameters for the operation.

        :param location: The location of the parameters to get.
        :returns: The parameters for the operation as a list of dictionaries.
        """
        if location:
            return [param for param in self.parameters if param.get("in") == location]
        return self.parameters

    def get_server(self, server_index: int = 0) -> str:
        """
        Get the URL of one of the servers for the operation.

        Operation-level servers take precedence over the servers declared at the top level of the specification.

        :param server_index: The index of the server to use.
        :returns: The server URL.
        :raises ValueError: If no servers are found in the specification or the index is out of bounds.
        """
        servers = self.operation_dict.get("servers", []) or self.spec_dict.get("servers", [])
        if not servers:
            raise ValueError("No servers found in the provided specification.")
        if not 0 <= server_index < len(servers):
            raise ValueError(
                f"Server index {server_index} is out of bounds. "
                f"Only {len(servers)} servers found."
            )
        return servers[server_index].get("url")


class OpenAPISpecification:
    """
    Represents an OpenAPI specification. See https://spec.openapis.org/oas/latest.html for details.
    """

    def __init__(self, spec_dict: Dict[str, Any]):
        """
        Initialize an OpenAPISpecification instance.

        :param spec_dict: The OpenAPI specification as a dictionary.
        :raises ValueError: If `spec_dict` is not a dictionary or lacks the `openapi`, `paths` or `servers` fields.
        """
        if not isinstance(spec_dict, dict):
            raise ValueError(
                f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}"
            )
        # just a crude sanity check, by no means a full validation
        missing_fields = [name for name in ("openapi", "paths", "servers") if name not in spec_dict]
        if missing_fields:
            # A single message argument keeps str(error) readable and avoids dumping the whole specification.
            raise ValueError(
                "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details. "
                f"Missing required field(s): {', '.join(missing_fields)}"
            )
        self.spec_dict = spec_dict

    @classmethod
    def from_str(cls, content: str) -> "OpenAPISpecification":
        """
        Create an OpenAPISpecification instance from a string.

        :param content: The string content of the OpenAPI specification.
        :returns: The OpenAPISpecification instance.
        :raises ValueError: If the content cannot be decoded as JSON or YAML.
        """
        try:
            loaded_spec = json.loads(content)
        except json.JSONDecodeError:
            # not JSON, fall back to YAML
            try:
                loaded_spec = yaml.safe_load(content)
            except yaml.YAMLError as e:
                raise ValueError(
                    "Content cannot be decoded as JSON or YAML: " + str(e)
                ) from e
        return cls(loaded_spec)

    @classmethod
    def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification":
        """
        Create an OpenAPISpecification instance from a file.

        :param spec_file: The file path to the OpenAPI specification.
        :returns: The OpenAPISpecification instance.
        :raises FileNotFoundError: If the specified file does not exist.
        :raises IOError: If an I/O error occurs while reading the file.
        :raises ValueError: If the file content cannot be decoded as JSON or YAML.
        """
        with open(spec_file, encoding="utf-8") as file:
            content = file.read()
        return cls.from_str(content)

    @classmethod
    def from_url(cls, url: str) -> "OpenAPISpecification":
        """
        Create an OpenAPISpecification instance from a URL.

        :param url: The URL to fetch the OpenAPI specification from.
        :returns: The OpenAPISpecification instance.
        :raises ConnectionError: If fetching the specification from the URL fails.
        """
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            content = response.text
        except requests.RequestException as e:
            raise ConnectionError(
                f"Failed to fetch the specification from URL: {url}. {e!s}"
            ) from e
        return cls.from_str(content)

    def find_operation_by_id(
        self, op_id: str, method: Optional[str] = None
    ) -> Operation:
        """
        Find an Operation by operationId.

        Every operation of every path is inspected, so paths with several HTTP methods or with path-level
        fields such as `parameters` are supported. The operationId has to match exactly.

        :param op_id: The operationId of the operation.
        :param method: The HTTP method of the operation. If given, only operations using it are considered.
        :returns: The matching operation
        :raises ValueError: If no operation is found with the given operationId.
        """
        wanted_method = method.lower() if method else None
        for path, path_item in self.spec_dict.get("paths", {}).items():
            if not isinstance(path_item, dict):
                continue
            for key, operation_dict in path_item.items():
                http_method = str(key).lower()
                # skip path-level fields that are not operations, e.g. "parameters" or "summary"
                if http_method not in VALID_HTTP_METHODS or not isinstance(operation_dict, dict):
                    continue
                if wanted_method and http_method != wanted_method:
                    continue
                if operation_dict.get("operationId") == op_id:
                    return Operation(path, http_method, operation_dict, self.spec_dict)
        raise ValueError(
            f"No operation found with operationId {op_id}, method {method}"
        )

    def get_operation_item(
        self, path: str, path_item: Dict[str, Any], method: Optional[str] = None
    ) -> Operation:
        """
        Gets a particular Operation item from the OpenAPI specification given the path and method.

        :param path: The path of the operation.
        :param path_item: The path item from the OpenAPI specification.
        :param method: The HTTP method of the operation. Required if the path item holds several operations.
        :returns: The operation
        :raises ValueError: If no operation matches, or if `method` is omitted and the path item holds
            more than one operation.
        """
        if method:
            operation_dict = path_item.get(method.lower(), {})
            if not operation_dict:
                raise ValueError(
                    f"No operation found for method {method} at path {path}"
                )
            return Operation(path, method.lower(), operation_dict, self.spec_dict)

        # only HTTP method keys are operations; ignore path-level fields such as "parameters"
        operations = [
            (key, operation_dict)
            for key, operation_dict in path_item.items()
            if str(key).lower() in VALID_HTTP_METHODS
        ]
        if len(operations) == 1:
            http_method, operation_dict = operations[0]
            return Operation(path, http_method, operation_dict, self.spec_dict)
        if len(operations) > 1:
            raise ValueError(
                f"Multiple operations found at path {path}, method parameter is required."
            )
        raise ValueError(f"No operations found at path {path} and method {method}")

    def get_security_schemes(self) -> Dict[str, Dict[str, Any]]:
        """
        Get the security schemes from the OpenAPI specification.

        :returns: The security schemes as a dictionary.
        """
        return self.spec_dict.get("components", {}).get("securitySchemes", {})
+ """ + for k, v in data.items(): + if isinstance(v, dict) and recursive: + serialize_secrets_inplace(v, keys, recursive=True) + elif k in keys and isinstance(v, Secret): + data[k] = v.to_dict() diff --git a/pyproject.toml b/pyproject.toml index eb261811..19d086ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] -dependencies = ["haystack-ai"] +dependencies = ["jsonref", "haystack-ai"] [project.urls] "CI: GitHub" = "https://github.com/deepset-ai/haystack-experimental/actions" "GitHub: issues" = "https://github.com/deepset-ai/haystack-experimental/issues" @@ -39,6 +39,9 @@ dependencies = [ # Test "pytest", "pytest-cov", + "fastapi", + "cohere", + "anthropic", # Linting "pylint", "ruff", diff --git a/test/components/tools/openapi/__init__.py b/test/components/tools/openapi/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/tools/openapi/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/tools/openapi/conftest.py b/test/components/tools/openapi/conftest.py new file mode 100644 index 00000000..224a29f0 --- /dev/null +++ b/test/components/tools/openapi/conftest.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from pathlib import Path +from typing import Union +from urllib.parse import urlparse + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from haystack.utils.url_validation import is_valid_http_url + +from haystack_experimental.components.tools.openapi._openapi import HttpClientError +from haystack_experimental.components.tools.openapi.types import OpenAPISpecification + + +@pytest.fixture() +def test_files_path(): + return Path(__file__).parent.parent.parent.parent / 
"test_files" + + +def create_openapi_spec(openapi_spec: Union[Path, str]) -> OpenAPISpecification: + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + return OpenAPISpecification.from_file(openapi_spec) + elif isinstance(openapi_spec, str): + if is_valid_http_url(openapi_spec): + return OpenAPISpecification.from_url(openapi_spec) + else: + return OpenAPISpecification.from_str(openapi_spec) + else: + raise ValueError( + "Invalid OpenAPI specification format. Expected file path or dictionary." + ) + + +class FastAPITestClient: + + def __init__(self, app: FastAPI): + self.app = app + self.client = TestClient(app) + + def strip_host(self, url: str) -> str: + parsed_url = urlparse(url) + new_path = parsed_url.path + if parsed_url.query: + new_path += "?" + parsed_url.query + return new_path + + def __call__(self, request: dict) -> dict: + # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail + # if the URL has a host. So we strip it here. 
+ url = self.strip_host(request["url"]) + try: + response = self.client.request( + request["method"], + url, + headers=request.get("headers", {}), + params=request.get("params", {}), + json=request.get("json", None), + auth=request.get("auth", None), + cookies=request.get("cookies", {}), + ) + response.raise_for_status() + return response.json() + except Exception as e: + # Handle HTTP errors + raise HttpClientError(f"HTTP error occurred: {e}") from e diff --git a/test/components/tools/openapi/test_openapi_client.py b/test/components/tools/openapi/test_openapi_client.py new file mode 100644 index 00000000..ad745125 --- /dev/null +++ b/test/components/tools/openapi/test_openapi_client.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec + +""" +Tests OpenAPIServiceClient with three FastAPI apps for different parameter types: + +- **greet_mix_params_body**: A POST endpoint `/greet/` accepting a JSON payload with a message, returning a +greeting with the name from the URL and the message from the payload. + +- **greet_params_only**: A GET endpoint `/greet-params/` taking a URL parameter, returning a greeting with +the name from the URL. + +- **greet_request_body_only**: A POST endpoint `/greet-body` accepting a JSON payload with a name and message, +returning a greeting with both. + +OpenAPI specs for these endpoints are in `openapi_greeting_service.yml` in `test/test_files` directory. 
+""" + + +class GreetBody(BaseModel): + message: str + name: str + + +class MessageBody(BaseModel): + message: str + + +# FastAPI app definitions +def create_greet_mix_params_body_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet/{name}") + def greet(name: str, body: MessageBody): + greeting = f"{body.message}, {name} from mix_params_body!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_params_only_app() -> FastAPI: + app = FastAPI() + + @app.get("/greet-params/{name}") + def greet_params(name: str): + greeting = f"Hello, {name} from params_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_request_body_only_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet-body") + def greet_request_body(body: GreetBody): + greeting = f"{body.message}, {body.name} from request_body_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPI: + + def test_greet_mix_params_body(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_mix_params_body_app())) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Bonjour"}', + "name": "greet", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Bonjour, John from mix_params_body!"} + + def test_greet_params_only(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_params_only_app())) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetParams", + }, + "type": 
"function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from params_only!"} + + def test_greet_request_body_only(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_request_body_only_app())) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Hola"}', + "name": "greetBody", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hola, John from request_body_only!"} diff --git a/test/components/tools/openapi/test_openapi_client_auth.py b/test/components/tools/openapi/test_openapi_client_auth.py new file mode 100644 index 00000000..6a91bcdd --- /dev/null +++ b/test/components/tools/openapi/test_openapi_client_auth.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import Depends, FastAPI, HTTPException, status +from fastapi.responses import JSONResponse +from fastapi.security import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + HTTPAuthorizationCredentials, + HTTPBasic, + HTTPBasicCredentials, + HTTPBearer, +) + +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec + +API_KEY = "secret_api_key" +BASIC_AUTH_USERNAME = "admin" +BASIC_AUTH_PASSWORD = "secret_password" + +API_KEY_QUERY = "secret_api_key_query" +API_KEY_COOKIE = "secret_api_key_cookie" +BEARER_TOKEN = "secret_bearer_token" + +OAUTH_TOKEN = "secret-oauth-token" + +api_key_query = APIKeyQuery(name="api_key") +api_key_cookie = APIKeyCookie(name="api_key") +bearer_auth = HTTPBearer() + +api_key_header = APIKeyHeader(name="X-API-Key") 
+basic_auth_http = HTTPBasic() + + +def create_greet_api_key_query_app() -> FastAPI: + app = FastAPI() + + def api_key_query_auth(api_key: str = Depends(api_key_query)): + if api_key != API_KEY_QUERY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-query/{name}") + def greet_api_key_query(name: str, api_key: str = Depends(api_key_query_auth)): + greeting = f"Hello, {name} from api_key_query_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_cookie_app() -> FastAPI: + app = FastAPI() + + def api_key_cookie_auth(api_key: str = Depends(api_key_cookie)): + if api_key != API_KEY_COOKIE: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-cookie/{name}") + def greet_api_key_cookie(name: str, api_key: str = Depends(api_key_cookie_auth)): + greeting = f"Hello, {name} from api_key_cookie_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_bearer_auth_app() -> FastAPI: + app = FastAPI() + + def bearer_auth_scheme( + credentials: HTTPAuthorizationCredentials = Depends(bearer_auth), # noqa: B008 + ): + if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return credentials.credentials + + @app.get("/greet-bearer-auth/{name}") + def greet_bearer_auth(name: str, token: str = Depends(bearer_auth_scheme)): + greeting = f"Hello, {name} from bearer_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_auth_app() -> FastAPI: + app = FastAPI() + + def api_key_auth(api_key: str = Depends(api_key_header)): + if api_key != API_KEY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 
API key") + return api_key + + @app.get("/greet-api-key/{name}") + def greet_api_key(name: str, api_key: str = Depends(api_key_auth)): + greeting = f"Hello, {name} from api_key_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_basic_auth_app() -> FastAPI: + app = FastAPI() + + def basic_auth(credentials: HTTPBasicCredentials = Depends(basic_auth_http)): # noqa: B008 + if credentials.username != BASIC_AUTH_USERNAME or credentials.password != BASIC_AUTH_PASSWORD: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") + return credentials.username + + @app.get("/greet-basic-auth/{name}") + def greet_basic_auth(name: str, username: str = Depends(basic_auth)): + greeting = f"Hello, {name} from basic_auth, using {username}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_oauth_auth_app() -> FastAPI: + app = FastAPI() + + def oauth_auth(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): # noqa: B008 + if token.credentials != OAUTH_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return token + + @app.get("/greet-oauth/{name}") + def greet_oauth(name: str, token: HTTPAuthorizationCredentials = Depends(oauth_auth)): # noqa: B008 + greeting = f"Hello, {name} from oauth_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPIAuth: + + def test_greet_api_key_auth(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_api_key_auth_app()), + credentials=API_KEY) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKey", + }, + "type": "function", + } + response = 
client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} + + def test_greet_api_key_query_auth(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_api_key_query_app()), + credentials=API_KEY_QUERY) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyQuery", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_query_auth, using secret_api_key_query"} + + def test_greet_api_key_cookie_auth(self, test_files_path): + + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), + request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), + credentials=API_KEY_COOKIE) + + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyCookie", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} \ No newline at end of file diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body.py b/test/components/tools/openapi/test_openapi_client_complex_request_body.py new file mode 100644 index 00000000..8ce2273e --- /dev/null +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json +from typing import List + +import pytest +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from 
haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec + + +class Customer(BaseModel): + name: str + email: str + + +class OrderItem(BaseModel): + product: str + quantity: int + + +class Order(BaseModel): + customer: Customer + items: List[OrderItem] + + +class OrderResponse(BaseModel): + orderId: str # noqa: N815 + status: str + totalAmount: float # noqa: N815 + + +def create_order_app() -> FastAPI: + app = FastAPI() + + @app.post("/orders") + def create_order(order: Order): + total_amount = sum(item.quantity * 10 for item in order.items) + response = OrderResponse( + orderId="ORDER-001", + status="CREATED", + totalAmount=total_amount, + ) + return JSONResponse(content=response.model_dump(), status_code=201) + + return app + + +class TestComplexRequestBody: + + @pytest.mark.parametrize("spec_file_path", ["openapi_order_service.yml", "openapi_order_service.json"]) + def test_create_order(self, spec_file_path, test_files_path): + path_element = "yaml" if spec_file_path.endswith(".yml") else "json" + + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / path_element / spec_file_path), + request_sender=FastAPITestClient(create_order_app())) + + client = OpenAPIServiceClient(config) + order_json = { + "customer": {"name": "John Doe", "email": "john@example.com"}, + "items": [ + {"product": "Product A", "quantity": 2}, + {"product": "Product B", "quantity": 1}, + ], + } + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": json.dumps(order_json), + "name": "createOrder", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == { + "orderId": "ORDER-001", + "status": "CREATED", + "totalAmount": 30, + } diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py 
class Identification(BaseModel):
    """Identity document of a payer: document type and its number."""

    type: str
    number: str


class Payer(BaseModel):
    """Payer details nested inside a :class:`PaymentRequest`."""

    name: str
    email: str
    identification: Identification


class PaymentRequest(BaseModel):
    """Request body accepted by the stub's ``POST /new_payment`` route."""

    transaction_amount: float
    description: str
    payment_method_id: str
    payer: Payer


class PaymentResponse(BaseModel):
    """Response body returned by the stub payment service."""

    transaction_id: str
    status: str
    message: str


def create_payment_app() -> FastAPI:
    """
    Build the in-process FastAPI stub of the payment service.

    The single ``POST /new_payment`` route asserts that the received amount is 100.0 and then
    answers with a fixed success payload and HTTP status 200.

    :returns: The configured FastAPI application.
    """
    payment_service = FastAPI()

    @payment_service.post("/new_payment")
    def process_payment(payment: PaymentRequest):
        # Sanity check on the deserialized body: the test in this module sends 100.0.
        assert payment.transaction_amount == 100.0
        body = PaymentResponse(
            transaction_id="TRANS-12345",
            status="SUCCESS",
            message="Payment processed successfully.",
        )
        return JSONResponse(content=body.model_dump(), status_code=200)

    return payment_service


class TestPaymentProcess:
    """Send a doubly nested request body (payer -> identification) through the client."""

    def test_process_payment(self, test_files_path):
        """Invoke ``processPayment`` and expect the stub's fixed success response."""
        spec = create_openapi_spec(test_files_path / "json" / "complex_types_openapi_service.json")
        sender = FastAPITestClient(create_payment_app())
        service = OpenAPIServiceClient(ClientConfiguration(openapi_spec=spec, request_sender=sender))

        payment_body = {
            "transaction_amount": 100.0,
            "description": "Test Payment",
            "payment_method_id": "CARD-123",
            "payer": {
                "name": "Alice Smith",
                "email": "alice@example.com",
                "identification": {"type": "CPF", "number": "123.456.789-00"},
            },
        }
        tool_call = {
            "id": "call_uniqueID123",
            "function": {
                "arguments": json.dumps(payment_body),
                "name": "processPayment",
            },
            "type": "function",
        }

        expected = {
            "transaction_id": "TRANS-12345",
            "status": "SUCCESS",
            "message": "Payment processed successfully.",
        }
        assert service.invoke(tool_call) == expected


class TestEdgeCases:
    """Failure modes of ``OpenAPIServiceClient.invoke`` that involve no working backend."""

    def test_missing_operation_id(self, test_files_path):
        """A tool call naming an unknown operation must raise ``ValueError``."""
        spec = create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml")
        # No app is handed to the test client - presumably fine because the operation
        # lookup fails before any request would be sent (confirm against the client).
        sender = FastAPITestClient(None)
        service = OpenAPIServiceClient(ClientConfiguration(openapi_spec=spec, request_sender=sender))

        tool_call = {
            "type": "function",
            "function": {
                "arguments": '{"name": "John", "message": "Hola"}',
                "name": "missingOperationId",
            },
        }
        with pytest.raises(ValueError, match="No operation found with operationId"):
            service.invoke(tool_call)

    # TODO: Add more tests for edge cases


def create_error_handling_app() -> FastAPI:
    """
    Build a FastAPI stub whose only route always fails.

    ``GET /error/{status_code}`` raises an ``HTTPException`` carrying the status code taken
    from the path.

    :returns: The configured FastAPI application.
    """
    failing_service = FastAPI()

    @failing_service.get("/error/{status_code}")
    def raise_http_error(status_code: int):
        raise HTTPException(status_code=status_code, detail=f"HTTP {status_code} error")

    return failing_service


class TestErrorHandling:
    """HTTP error responses must surface to the caller as ``HttpClientError``."""

    @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500])
    def test_http_error_handling(self, test_files_path, status_code):
        """Each 4xx/5xx reply raises ``HttpClientError`` whose message mentions the status code."""
        spec = create_openapi_spec(test_files_path / "yaml" / "openapi_error_handling.yml")
        sender = FastAPITestClient(create_error_handling_app())
        service = OpenAPIServiceClient(ClientConfiguration(openapi_spec=spec, request_sender=sender))

        tool_call = {
            "type": "function",
            "function": {
                "arguments": json.dumps({"status_code": status_code}),
                "name": "raiseHttpError",
            },
        }
        with pytest.raises(HttpClientError) as exc_info:
            service.invoke(tool_call)

        raised_message = str(exc_info.value)
        assert str(status_code) in raised_message
class TestClientLive:
    """Live integration tests that call real services with hand-written tool-call payloads."""

    @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set")
    @pytest.mark.integration
    def test_serperdev(self, test_files_path):
        """Run a Serper search through the client and look for a keyword in the reply."""
        config = ClientConfiguration(
            openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"),
            credentials=os.getenv("SERPERDEV_API_KEY"),
        )
        serper_api = OpenAPIServiceClient(config)
        payload = {
            "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
            "function": {
                "arguments": '{"q": "Who was Nikola Tesla?"}',
                "name": "serperdev_search",
            },
            "type": "function",
        }
        response = serper_api.invoke(payload)
        assert "invention" in str(response)

    @pytest.mark.integration
    @pytest.mark.skip(reason="This test hits rate limit on Github API. Skip for now.")
    def test_github(self, test_files_path):
        """Compare two branches through the GitHub spec and look for the owner in the reply."""
        config = ClientConfiguration(
            openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")
        )
        api = OpenAPIServiceClient(config)

        params = {"owner": "deepset-ai", "repo": "haystack", "basehead": "main...add_default_adapter_filters"}
        payload = {
            "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
            "function": {
                "arguments": json.dumps(params),
                # Must be the name github_compare.yml exposes; the schema-conversion tests
                # assert it is "compare_branches". The previous value "compare" would not
                # resolve to an operation once this test is re-enabled.
                "name": "compare_branches",
            },
            "type": "function",
        }
        response = api.invoke(payload)
        assert "deepset" in str(response)
OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"), + llm_provider=LLMProvider.ANTHROPIC) + + client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + response = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + tools=config.get_tools_definitions(), + messages=[ + { + "role": "user", + "content": "Compare branches main and add_default_adapter_filters in repo" + " haystack and owner deepset-ai", + } + ], + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) diff --git a/test/components/tools/openapi/test_openapi_client_live_cohere.py b/test/components/tools/openapi/test_openapi_client_live_cohere.py new file mode 100644 index 00000000..49e8cde6 --- /dev/null +++ b/test/components/tools/openapi/test_openapi_client_live_cohere.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import cohere +import pytest + +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ + LLMProvider +from test.components.tools.openapi.conftest import create_openapi_spec + +# Copied from Cohere's documentation +preamble = """ +## Task & Context +You help people answer their questions and other requests interactively. 
You will be asked a very wide array of + requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to + help you, which you use to research your answer. You should focus on serving the user's needs as best you can, + which will be wide-ranging. + +## Style Guide +Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and + spelling. +""" + + +class TestClientLiveCohere: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"), + credentials=os.getenv("SERPERDEV_API_KEY"), + llm_provider=LLMProvider.COHERE) + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + tools=config.get_tools_definitions(), + message="Do a google search: Who was Nikola Tesla?", + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"), + llm_provider=LLMProvider.COHERE) + + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + 
class TestClientLiveOpenAPI:
    """Live tests: OpenAI produces the tool call, ``OpenAPIServiceClient`` executes it."""

    @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set")
    @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_serperdev(self, test_files_path):
        """Let the model request a Serper search, then invoke that tool call three times."""
        serper_spec = create_openapi_spec(test_files_path / "yaml" / "serper.yml")
        config = ClientConfiguration(openapi_spec=serper_spec, credentials=os.getenv("SERPERDEV_API_KEY"))
        llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        completion = llm.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Do a serperdev google search: Who was Nikola Tesla?"}],
            tools=config.get_tools_definitions(),
        )
        service_api = OpenAPIServiceClient(config)

        # The same completion is replayed on purpose, so the client and tool are used
        # several times in a row; each round checks a different keyword.
        for keyword in ("inventions", "Serbian", "American"):
            assert keyword in str(service_api.invoke(completion))

    @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    @pytest.mark.skip("This test hits rate limit on Github API. Skip for now.")
    def test_github(self, test_files_path):
        """Let the model request a branch comparison and run it against the GitHub spec."""
        github_spec = create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")
        config = ClientConfiguration(openapi_spec=github_spec)
        llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        user_turn = {
            "role": "user",
            "content": "Compare branches main and add_default_adapter_filters in repo"
            " haystack and owner deepset-ai",
        }
        completion = llm.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[user_turn],
            tools=config.get_tools_definitions(),
        )
        service_response = OpenAPIServiceClient(config).invoke(completion)
        assert "deepset" in str(service_response)

    @pytest.mark.skipif("FIRECRAWL_API_KEY" not in os.environ, reason="FIRECRAWL_API_KEY not set")
    @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
    @pytest.mark.integration
    def test_firecrawl(self):
        """Use two different operations (scrape, then search) of one remotely fetched spec."""
        openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"
        config = ClientConfiguration(
            openapi_spec=create_openapi_spec(openapi_spec_url),
            credentials=os.getenv("FIRECRAWL_API_KEY"),
        )
        llm = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        def ask(prompt):
            # One user turn, always offering every tool defined by the spec.
            return llm.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": prompt}],
                tools=config.get_tools_definitions(),
            )

        scrape_call = ask("Scrape URL: https://news.ycombinator.com/")
        service_api = OpenAPIServiceClient(config)
        scrape_result = service_api.invoke(scrape_call)
        assert isinstance(scrape_result, dict)
        assert scrape_result.get("success", False), "Firecrawl scrape API call failed"

        # now test the same openapi service but different endpoint/tool
        top_k = 2
        search_call = ask(f"Search Google for `Why was Sam Altman ousted from OpenAI?`, limit to {top_k} results")
        search_result = service_api.invoke(search_call)
        assert isinstance(search_result, dict)
        assert search_result.get("success", False), "Firecrawl search API call failed"
        assert len(search_result.get("data", [])) == top_k
        assert "Sam" in str(search_result)
"pageOptions": { + "type": "object", + "description": "", + "required": False, + "properties": { + "onlyMainContent": { + "type": "bool", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "required": False, + }, + "includeHtml": { + "type": "bool", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "required": False, + }, + "screenshot": { + "type": "bool", + "description": "Include a screenshot of the top of the page that you are scraping.", + "required": False, + }, + "waitFor": { + "type": "int", + "description": "Wait x amount of milliseconds for the page to load to fetch content", + "required": False, + }, + "removeTags": { + "type": "list", + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'", + "required": False, + }, + "headers": { + "type": "object", + "description": "Headers to send with the request. Can be used to send cookies, user-agent, etc.", + "properties": {}, + "required": False, + }, + }, + }, + "extractorOptions": { + "type": "object", + "description": "Options for LLM-based extraction of structured information from the page content", + "required": False, + "properties": { + "mode": { + "type": "str", + "description": "The extraction mode to use, currently supports 'llm-extraction'", + "required": False, + }, + "extractionPrompt": { + "type": "str", + "description": "A prompt describing what information to extract from the page", + "required": False, + }, + "extractionSchema": { + "type": "object", + "description": "The schema for the data to be extracted", + "properties": {}, + "required": False, + }, + }, + }, + "timeout": { + "type": "int", + "description": "Timeout in milliseconds for the request", + "required": False, + }, + } + + def test_github(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") + 
functions = cohere_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "compare_branches" + assert function["description"] == "Compares two branches against one another." + assert function["parameter_definitions"] == { + "basehead": { + "description": "The base branch and head branch to compare." + " This parameter expects the format `BASE...HEAD`", + "type": "str", + "required": True, + }, + "owner": { + "description": "The repository owner, usually a company or orgnization", + "type": "str", + "required": True, + }, + "repo": {"description": "The repository itself, the project", "type": "str", "required": True}, + } + + def test_complex_types(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + functions = cohere_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "processPayment" + assert function["description"] == "Process a new payment using the specified payment method" + assert function["parameter_definitions"] == { + "transaction_amount": {"type": "float", "description": "The amount to be paid", "required": True}, + "description": {"type": "str", "description": "A brief description of the payment", "required": True}, + "payment_method_id": {"type": "str", "description": "The payment method to be used", "required": True}, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": {"type": "str", "description": "The payer's name", "required": True}, + "email": {"type": "str", "description": "The payer's email address", "required": True}, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "str", + "description": "The type of identification document 
class TestOpenAPISchemaConversion:
    """Check the OpenAI and Anthropic tool definitions generated from OpenAPI specs."""

    @staticmethod
    def _single_function(provider, spec):
        """
        Convert ``spec`` for ``provider`` and return the only generated tool.

        :param provider: Either ``"openai"`` or ``"anthropic"``.
        :param spec: The ``OpenAPISpecification`` to convert.
        :returns: A ``(function, schema)`` tuple: the dict holding ``name``/``description`` and
            the provider-specific parameter schema (``parameters`` for OpenAI, ``input_schema``
            for Anthropic).
        """
        functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec)
        assert functions
        assert len(functions) == 1
        if provider == "openai":
            # OpenAI nests the definition under a "function" key.
            function = functions[0]["function"]
            return function, function["parameters"]
        function = functions[0]
        return function, function["input_schema"]

    @pytest.mark.parametrize("provider", ["openai", "anthropic"])
    def test_serperdev(self, test_files_path, provider):
        """Serper spec: one tool with a single required string parameter ``q``."""
        spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "serper.yml")
        function, schema = self._single_function(provider, spec)
        assert function["name"] == "serperdev_search"
        assert function["description"] == "Search the web with Google"
        # Select the schema first, then compare. Written as one expression,
        # `a if cond else b == {...}` parses as `a if cond else (b == {...})`,
        # which only truth-tested the OpenAI schema instead of comparing it.
        assert schema == {"type": "object", "properties": {"q": {"type": "string"}}, "required": ["q"]}

    @pytest.mark.parametrize("provider", ["openai", "anthropic"])
    def test_github(self, test_files_path, provider: str):
        """GitHub compare spec: one tool with three required string parameters."""
        spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml")
        function, schema = self._single_function(provider, spec)
        assert function["name"] == "compare_branches"
        assert function["description"] == "Compares two branches against one another."
        assert schema == {
            "type": "object",
            "properties": {
                "basehead": {
                    "type": "string",
                    "description": "The base branch and head branch to compare. "
                    "This parameter expects the format `BASE...HEAD`",
                },
                "owner": {
                    "type": "string",
                    "description": "The repository owner, usually a company or orgnization",
                },
                "repo": {"type": "string", "description": "The repository itself, the project"},
            },
            "required": ["basehead", "owner", "repo"],
        }

    @pytest.mark.parametrize("provider", ["openai", "anthropic"])
    def test_complex_types(self, test_files_path, provider: str):
        """Payment spec: nested objects keep their own ``properties`` and ``required`` lists."""
        spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json")
        function, schema = self._single_function(provider, spec)
        assert function["name"] == "processPayment"
        assert function["description"] == "Process a new payment using the specified payment method"
        assert schema == {
            "type": "object",
            "properties": {
                "transaction_amount": {"type": "number", "description": "The amount to be paid"},
                "description": {"type": "string", "description": "A brief description of the payment"},
                "payment_method_id": {"type": "string", "description": "The payment method to be used"},
                "payer": {
                    "type": "object",
                    "description": "Information about the payer, including their name, email, "
                    "and identification number",
                    "properties": {
                        "name": {"type": "string", "description": "The payer's name"},
                        "email": {"type": "string", "description": "The payer's email address"},
                        "identification": {
                            "type": "object",
                            "description": "The payer's identification number",
                            "properties": {
                                "type": {
                                    "type": "string",
                                    "description": "The type of identification document (e.g., CPF, CNPJ)",
                                },
                                "number": {"type": "string", "description": "The identification number"},
                            },
                            "required": ["type", "number"],
                        },
                    },
                    "required": ["name", "email", "identification"],
                },
            },
            "required": ["transaction_amount", "description", "payment_method_id", "payer"],
        }
class TestOpenAPITool:
    """Serialization, construction and live-run tests for the ``OpenAPITool`` component."""

    @staticmethod
    def _assert_single_json_dict_message(results):
        """Assert ``results["service_response"]`` is one ``ChatMessage`` whose content is a JSON object."""
        messages = results["service_response"]
        assert isinstance(messages, list)
        assert len(messages) == 1
        assert isinstance(messages[0], ChatMessage)

        try:
            json_response = json.loads(results["service_response"][0].content)
            assert isinstance(json_response, dict)
        except json.JSONDecodeError:
            pytest.fail("Response content is not valid JSON")

    def test_to_dict(self, monkeypatch):
        """``to_dict`` emits the init parameters, with secrets serialized as env-var references."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
        monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key")

        openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"

        generator_params = {
            "model": "gpt-3.5-turbo",
            "api_key": Secret.from_env_var("OPENAI_API_KEY"),
        }
        tool = OpenAPITool(
            generator_api=LLMProvider.OPENAI,
            generator_api_params=generator_params,
            spec=openapi_spec_url,
            credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
        )

        expected = {
            "type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool",
            "init_parameters": {
                "generator_api": "openai",
                "generator_api_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                },
                "spec": openapi_spec_url,
                "credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
            },
        }
        assert tool.to_dict() == expected

    def test_from_dict(self, monkeypatch):
        """``from_dict`` restores the provider, generator params, spec and credentials."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key")
        openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"
        serialized = {
            "type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool",
            "init_parameters": {
                "generator_api": "openai",
                "generator_api_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                },
                "spec": openapi_spec_url,
                "credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
            },
        }

        tool = OpenAPITool.from_dict(serialized)

        assert tool.generator_api == LLMProvider.OPENAI
        assert tool.generator_api_params == {
            "model": "gpt-3.5-turbo",
            "api_key": Secret.from_env_var("OPENAI_API_KEY"),
        }
        assert tool.spec == openapi_spec_url
        assert tool.credentials == Secret.from_env_var("SERPERDEV_API_KEY")

    def test_initialize_with_invalid_openapi_spec_url(self):
        """A URL that cannot be fetched raises ``ConnectionError`` at construction time."""
        generator_params = {
            "model": "gpt-3.5-turbo",
            "api_key": Secret.from_token("not_needed"),
        }
        with pytest.raises(ConnectionError, match="Failed to fetch the specification from URL"):
            OpenAPITool(
                generator_api=LLMProvider.OPENAI,
                generator_api_params=generator_params,
                spec="https://raw.githubusercontent.com/invalid_openapi.json",
            )

    def test_initialize_with_invalid_openapi_spec_path(self):
        """A spec string that is not a usable source raises ``ValueError``."""
        generator_params = {
            "model": "gpt-3.5-turbo",
            "api_key": Secret.from_token("not_needed"),
        }
        with pytest.raises(ValueError, match="Invalid OpenAPI specification source"):
            OpenAPITool(
                generator_api=LLMProvider.OPENAI,
                generator_api_params=generator_params,
                spec="invalid_openapi.json",
            )

    def test_initialize_with_valid_openapi_spec_url_and_credentials(self):
        """A reachable spec URL plus credentials yields a fully wired tool."""
        openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"
        credentials = Secret.from_token("")
        generator_params = {
            "model": "gpt-3.5-turbo",
            "api_key": Secret.from_token("not_needed"),
        }
        tool = OpenAPITool(
            generator_api=LLMProvider.OPENAI,
            generator_api_params=generator_params,
            spec=openapi_spec_url,
            credentials=credentials,
        )

        assert tool.generator_api == LLMProvider.OPENAI
        assert isinstance(tool.chat_generator, OpenAIChatGenerator)
        assert tool.config_openapi is not None
        assert tool.open_api_service is not None

    @pytest.mark.skipif(
        "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
    )
    @pytest.mark.skipif(
        "OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set"
    )
    @pytest.mark.integration
    def test_run_live_openai(self):
        """Live run with OpenAI (no generator params passed) against the Serper spec."""
        tool = OpenAPITool(
            generator_api=LLMProvider.OPENAI,
            spec="https://bit.ly/serper_dev_spec_yaml",
            credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
        )
        question = ChatMessage.from_user("Search for 'Who was Nikola Tesla?'")

        self._assert_single_json_dict_message(tool.run(messages=[question]))

    @pytest.mark.skipif(
        "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
    )
    @pytest.mark.skipif(
        "ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set"
    )
    @pytest.mark.integration
    def test_run_live_anthropic(self):
        """Live run with Anthropic against the Serper spec."""
        tool = OpenAPITool(
            generator_api=LLMProvider.ANTHROPIC,
            generator_api_params={"model": "claude-3-opus-20240229"},
            spec="https://bit.ly/serper_dev_spec_yaml",
            credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
        )
        question = ChatMessage.from_user("Search for 'Who was Nikola Tesla?'")

        self._assert_single_json_dict_message(tool.run(messages=[question]))

    @pytest.mark.skipif(
        "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
    )
    @pytest.mark.skipif(
        "COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set"
    )
    @pytest.mark.integration
    def test_run_live_cohere(self):
        """Live run with Cohere against the Serper spec."""
        tool = OpenAPITool(
            generator_api=LLMProvider.COHERE,
            generator_api_params={"model": "command-r"},
            spec="https://bit.ly/serper_dev_spec_yaml",
            credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
        )
        question = ChatMessage.from_user("Search for 'Who was Nikola Tesla?'")

        self._assert_single_json_dict_message(tool.run(messages=[question]))
"number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } +} diff --git a/test/test_files/json/complex_types_openapi_service.json b/test/test_files/json/complex_types_openapi_service.json new file mode 100644 index 00000000..3ea04f8a --- /dev/null +++ b/test/test_files/json/complex_types_openapi_service.json @@ -0,0 +1,103 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Payment API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://localhost:8080" + } + ], + "paths": { + "/new_payment": { + "post": { + "summary": "Process a new payment", + "description": "Process a new payment using the specified payment method", + "operationId": "processPayment", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "transaction_amount": { + "type": "number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief 
description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "$ref": "#/components/schemas/Payer" + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } + } + } + }, + "responses": { + "200": { + "description": "Payment processed successfully" + }, + "400": { + "description": "Invalid request" + } + } + } + } + }, + "components": { + "schemas": { + "Payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + } + } +} diff --git a/test/test_files/json/firecrawl_openapi_spec.json b/test/test_files/json/firecrawl_openapi_spec.json new file mode 100644 index 00000000..e5868571 --- /dev/null +++ b/test/test_files/json/firecrawl_openapi_spec.json @@ -0,0 +1,881 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Firecrawl API", + "version": "1.0.0", + "description": "API for interacting with Firecrawl services to perform web scraping and crawling tasks.", + "contact": { + "name": "Firecrawl Support", + "url": "https://firecrawl.dev/support", + "email": "support@firecrawl.dev" + } + }, + "servers": [ + { + "url": "https://api.firecrawl.dev/v0" + } + ], + "paths": { + "/scrape": { + "post": { + "summary": "Scrape a single URL and optionally extract information using an LLM", 
+ "operationId": "scrapeAndExtractFromUrl", + "tags": ["Scraping"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "The URL to scrape" + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + }, + "screenshot": { + "type": "boolean", + "description": "Include a screenshot of the top of the page that you are scraping.", + "default": false + }, + "waitFor": { + "type": "integer", + "description": "Wait x amount of milliseconds for the page to load to fetch content", + "default": 0 + }, + "removeTags": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'" + }, + "headers": { + "type": "object", + "description": "Headers to send with the request. Can be used to send cookies, user-agent, etc." 
+ } + } + }, + "extractorOptions": { + "type": "object", + "description": "Options for LLM-based extraction of structured information from the page content", + "properties": { + "mode": { + "type": "string", + "enum": ["llm-extraction"], + "description": "The extraction mode to use, currently supports 'llm-extraction'" + }, + "extractionPrompt": { + "type": "string", + "description": "A prompt describing what information to extract from the page" + }, + "extractionSchema": { + "type": "object", + "additionalProperties": true, + "description": "The schema for the data to be extracted", + "required": [ + "company_mission", + "supports_sso", + "is_open_source" + ] + } + } + }, + "timeout": { + "type": "integer", + "description": "Timeout in milliseconds for the request", + "default": 30000 + } + }, + "required": ["url"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScrapeResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl": { + "post": { + "summary": "Crawl multiple URLs based on options", + "operationId": "crawlUrls", + "tags": ["Crawling"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "The base URL to start crawling from" + }, + "crawlerOptions": { + "type": "object", + "properties": { + "includes": { + "type": "array", + "items": { + "type": "string" + }, + "description": "URL patterns to include" + }, + "excludes": { + "type": "array", + "items": { + "type": "string" + }, + "description": "URL patterns to exclude" + }, + "generateImgAltText": { + "type": "boolean", + "description": "Generate alt 
text for images using LLMs (must have a paid plan)", + "default": false + }, + "returnOnlyUrls": { + "type": "boolean", + "description": "If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents.", + "default": false + }, + "maxDepth": { + "type": "integer", + "description": "Maximum depth to crawl. Depth 1 is the base URL, depth 2 is the base URL and its direct children, and so on." + }, + "mode": { + "type": "string", + "enum": ["default", "fast"], + "description": "The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites.", + "default": "default" + }, + "ignoreSitemap": { + "type": "boolean", + "description": "Ignore the website sitemap when crawling", + "default": false + }, + "limit": { + "type": "integer", + "description": "Maximum number of pages to crawl", + "default": 10000 + }, + "allowBackwardCrawling": { + "type": "boolean", + "description": "Allow backward crawling (crawl from the base URL to the previous URLs)", + "default": false + } + } + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + }, + "screenshot": { + "type": "boolean", + "description": "Include a screenshot of the top of the page that you are scraping.", + "default": false + }, + "headers": { + "type": "object", + "description": "Headers to send with the request when scraping. Can be used to send cookies, user-agent, etc." 
+ }, + "removeTags": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'" + }, + "replaceAllPathsWithAbsolutePaths": { + "type": "boolean", + "description": "Replace all relative paths with absolute paths for images and links", + "default": false + } + } + } + }, + "required": ["url"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CrawlResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/search": { + "post": { + "summary": "Search for a keyword in Google, returns top page results with markdown content for each page", + "operationId": "searchGoogle", + "tags": ["Search"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "format": "uri", + "description": "The query to search for" + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "fetchPageContent": { + "type": "boolean", + "description": "Fetch the content of each page. If false, defaults to a basic fast serp API.", + "default": true + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + } + } + }, + "searchOptions": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "Maximum number of results. Max is 20 during beta." 
+ } + } + } + }, + "required": ["query"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SearchResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl/status/{jobId}": { + "get": { + "tags": ["Crawl"], + "summary": "Get the status of a crawl job", + "operationId": "getCrawlStatus", + "security": [ + { + "bearerAuth": [] + } + ], + "parameters": [ + { + "name": "jobId", + "in": "path", + "description": "ID of the crawl job", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "Status of the job (completed, active, failed, paused)" + }, + "current": { + "type": "integer", + "description": "Current page number" + }, + "current_url": { + "type": "string", + "description": "Current URL being scraped" + }, + "current_step": { + "type": "string", + "description": "Current step in the process" + }, + "total": { + "type": "integer", + "description": "Total number of pages" + }, + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CrawlStatusResponseObj" + }, + "description": "Data returned from the job (null when it is in progress)" + }, + "partial_data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CrawlStatusResponseObj" + }, + "description": "Partial documents returned as it is being crawled (streaming). **This feature is currently in alpha - expect breaking changes** When a page is ready, it will append to the partial_data array, so there is no need to wait for the entire website to be crawled. There is a max of 50 items in the array response. 
The oldest item (top of the array) will be removed when the new item is added to the array." + } + } + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl/cancel/{jobId}": { + "delete": { + "tags": ["Crawl"], + "summary": "Cancel a crawl job", + "operationId": "cancelCrawlJob", + "security": [ + { + "bearerAuth": [] + } + ], + "parameters": [ + { + "name": "jobId", + "in": "path", + "description": "ID of the crawl job", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "Returns cancelled." + } + } + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + } + }, + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer" + } + }, + "schemas": { + "ScrapeResponse": { + "type": "object", + "properties": { + "success": { + "type": "boolean" + }, + "data": { + "type": "object", + "properties": { + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "html": { + "type": "string", + "nullable": true, + "description": "Raw HTML content of the page if `includeHtml` is true" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "keywords": { + "type": "string", + "nullable": true + }, + "robots": { + "type": "string", + "nullable": true + }, + "ogTitle": { + "type": "string", + "nullable": true + }, + "ogDescription": { + "type": "string", + "nullable": true + }, + "ogUrl": { + "type": 
"string", + "format": "uri", + "nullable": true + }, + "ogImage": { + "type": "string", + "nullable": true + }, + "ogAudio": { + "type": "string", + "nullable": true + }, + "ogDeterminer": { + "type": "string", + "nullable": true + }, + "ogLocale": { + "type": "string", + "nullable": true + }, + "ogLocaleAlternate": { + "type": "array", + "items": { + "type": "string" + }, + "nullable": true + }, + "ogSiteName": { + "type": "string", + "nullable": true + }, + "ogVideo": { + "type": "string", + "nullable": true + }, + "dctermsCreated": { + "type": "string", + "nullable": true + }, + "dcDateCreated": { + "type": "string", + "nullable": true + }, + "dcDate": { + "type": "string", + "nullable": true + }, + "dctermsType": { + "type": "string", + "nullable": true + }, + "dcType": { + "type": "string", + "nullable": true + }, + "dctermsAudience": { + "type": "string", + "nullable": true + }, + "dctermsSubject": { + "type": "string", + "nullable": true + }, + "dcSubject": { + "type": "string", + "nullable": true + }, + "dcDescription": { + "type": "string", + "nullable": true + }, + "dctermsKeywords": { + "type": "string", + "nullable": true + }, + "modifiedTime": { + "type": "string", + "nullable": true + }, + "publishedTime": { + "type": "string", + "nullable": true + }, + "articleTag": { + "type": "string", + "nullable": true + }, + "articleSection": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + }, + "pageStatusCode": { + "type": "integer", + "description": "The status code of the page" + }, + "pageError": { + "type": "string", + "nullable": true, + "description": "The error message of the page" + } + } + }, + "llm_extraction": { + "type": "object", + "description": "Displayed when using LLM Extraction. Extracted data from the page following the schema defined.", + "nullable": true + }, + "warning": { + "type": "string", + "nullable": true, + "description": "Can be displayed when using LLM Extraction. 
Warning message will let you know any issues with the extraction." + } + } + } + } + }, + "CrawlStatusResponseObj": { + "type": "object", + "properties": { + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "html": { + "type": "string", + "nullable": true, + "description": "Raw HTML content of the page if `includeHtml` is true" + }, + "index": { + "type": "integer", + "description": "The number of the page that was crawled. This is useful for `partial_data` so you know which page the data is from." + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "keywords": { + "type": "string", + "nullable": true + }, + "robots": { + "type": "string", + "nullable": true + }, + "ogTitle": { + "type": "string", + "nullable": true + }, + "ogDescription": { + "type": "string", + "nullable": true + }, + "ogUrl": { + "type": "string", + "format": "uri", + "nullable": true + }, + "ogImage": { + "type": "string", + "nullable": true + }, + "ogAudio": { + "type": "string", + "nullable": true + }, + "ogDeterminer": { + "type": "string", + "nullable": true + }, + "ogLocale": { + "type": "string", + "nullable": true + }, + "ogLocaleAlternate": { + "type": "array", + "items": { + "type": "string" + }, + "nullable": true + }, + "ogSiteName": { + "type": "string", + "nullable": true + }, + "ogVideo": { + "type": "string", + "nullable": true + }, + "dctermsCreated": { + "type": "string", + "nullable": true + }, + "dcDateCreated": { + "type": "string", + "nullable": true + }, + "dcDate": { + "type": "string", + "nullable": true + }, + "dctermsType": { + "type": "string", + "nullable": true + }, + "dcType": { + "type": "string", + "nullable": true + }, + "dctermsAudience": { + "type": "string", + "nullable": true + }, + "dctermsSubject": { + "type": "string", + "nullable": true + }, + "dcSubject": { + "type": "string", 
+ "nullable": true + }, + "dcDescription": { + "type": "string", + "nullable": true + }, + "dctermsKeywords": { + "type": "string", + "nullable": true + }, + "modifiedTime": { + "type": "string", + "nullable": true + }, + "publishedTime": { + "type": "string", + "nullable": true + }, + "articleTag": { + "type": "string", + "nullable": true + }, + "articleSection": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + }, + "pageStatusCode": { + "type": "integer", + "description": "The status code of the page" + }, + "pageError": { + "type": "string", + "nullable": true, + "description": "The error message of the page" + } + } + } + } + }, + "SearchResponse": { + "type": "object", + "properties": { + "success": { + "type": "boolean" + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + } + } + } + } + } + } + } + }, + "CrawlResponse": { + "type": "object", + "properties": { + "jobId": { + "type": "string" + } + } + } + } + }, + "security": [ + { + "bearerAuth": [] + } + ] +} \ No newline at end of file diff --git a/test/test_files/json/openapi_order_service.json b/test/test_files/json/openapi_order_service.json new file mode 100644 index 00000000..3e24661a --- /dev/null +++ b/test/test_files/json/openapi_order_service.json @@ -0,0 +1,109 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Order Service", + "version": "1.0.0" + }, + "servers": [{"url": "http://localhost"}], + "paths": { + "/orders": { + "post": { + "summary": "Create a new order", + "operationId": "createOrder", + "requestBody": { + "required": true, + "content": 
{ + "application/json": { + "schema": { + "$ref": "#/components/schemas/Order" + } + } + } + }, + "responses": { + "201": { + "description": "Created", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OrderResponse" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Order": { + "type": "object", + "properties": { + "customer": { + "$ref": "#/components/schemas/Customer" + }, + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OrderItem" + } + } + }, + "required": [ + "customer", + "items" + ] + }, + "Customer": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "email": { + "type": "string" + } + }, + "required": [ + "name", + "email" + ] + }, + "OrderItem": { + "type": "object", + "properties": { + "product": { + "type": "string" + }, + "quantity": { + "type": "integer" + } + }, + "required": [ + "product", + "quantity" + ] + }, + "OrderResponse": { + "type": "object", + "properties": { + "orderId": { + "type": "string" + }, + "status": { + "type": "string" + }, + "totalAmount": { + "type": "number" + } + }, + "required": [ + "orderId", + "status", + "totalAmount" + ] + } + } + } +} \ No newline at end of file diff --git a/test/test_files/json/serperdev_openapi_spec.json b/test/test_files/json/serperdev_openapi_spec.json new file mode 100644 index 00000000..123c993a --- /dev/null +++ b/test/test_files/json/serperdev_openapi_spec.json @@ -0,0 +1,62 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "SerperDev", + "version": "1.0.0", + "description": "API for performing search queries" + }, + "servers": [ + { + "url": "https://google.serper.dev" + } + ], + "paths": { + "/search": { + "post": { + "operationId": "search", + "description": "Search the web with Google", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "q": { + "type": "string" + } + } + } + } + } + }, + "responses": { 
+ "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + }, + "security": [ + { + "apikey": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } +} diff --git a/test/test_files/yaml/github_compare.yml b/test/test_files/yaml/github_compare.yml new file mode 100644 index 00000000..e14575b7 --- /dev/null +++ b/test/test_files/yaml/github_compare.yml @@ -0,0 +1,438 @@ +openapi: 3.1.0 +info: + title: Github API + description: Enables interaction with OpenAPI + version: v1.0.0 +servers: + - url: https://api.github.com +paths: + /repos/{owner}/{repo}/compare/{basehead}: + get: + summary: Compare two branches + description: Compares two branches against one another. + tags: + - repos + operationId: compare_branches + externalDocs: + description: API method documentation + url: >- + https://docs.github.com/enterprise-server@3.9/rest/commits/commits#compare-two-commits + parameters: + - name: basehead + description: >- + The base branch and head branch to compare. 
This parameter expects + the format `BASE...HEAD` + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: owner + description: The repository owner, usually a company or orgnization + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: repo + description: The repository itself, the project + in: path + required: true + x-multi-segment: true + schema: + type: string + responses: + '200': + description: Response + content: + application/json: + schema: + $ref: '#/components/schemas/commit-comparison' + x-github: + githubCloudOnly: false + enabledForGitHubApps: true + category: commits + subcategory: commits +components: + schemas: + commit-comparison: + title: Commit Comparison + description: Commit Comparison + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/compare/master...topic + html_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic + permalink_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/compare/octocat:bbcd538c8e72b8c175046e27cc8f907076331401...octocat:0328041d1152db8ae77652d1618a02e57f745f17 + diff_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.diff + patch_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.patch + base_commit: + $ref: '#/components/schemas/commit' + merge_base_commit: + $ref: '#/components/schemas/commit' + status: + type: string + enum: + - diverged + - ahead + - behind + - identical + example: ahead + ahead_by: + type: integer + example: 4 + behind_by: + type: integer + example: 5 + total_commits: + type: integer + example: 6 + commits: + type: array + items: + $ref: '#/components/schemas/commit' + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - 
url + - html_url + - permalink_url + - diff_url + - patch_url + - base_commit + - merge_base_commit + - status + - ahead_by + - behind_by + - total_commits + - commits + nullable-git-user: + title: Git User + description: Metaproperties for Git author/committer information. + type: object + properties: + name: + type: string + example: '"Chris Wanstrath"' + email: + type: string + example: '"chris@ozmm.org"' + date: + type: string + example: '"2007-10-29T02:42:39.000-07:00"' + nullable: true + nullable-simple-user: + title: Simple User + description: A GitHub user. + type: object + properties: + name: + nullable: true + type: string + email: + nullable: true + type: string + login: + type: string + example: octocat + id: + type: integer + example: 1 + node_id: + type: string + example: MDQ6VXNlcjE= + avatar_url: + type: string + format: uri + example: https://github.com/images/error/octocat_happy.gif + gravatar_id: + type: string + example: 41d064eb2195891e12d0413f63227ea7 + nullable: true + url: + type: string + format: uri + example: https://api.github.com/users/octocat + html_url: + type: string + format: uri + example: https://github.com/octocat + followers_url: + type: string + format: uri + example: https://api.github.com/users/octocat/followers + following_url: + type: string + example: https://api.github.com/users/octocat/following{/other_user} + gists_url: + type: string + example: https://api.github.com/users/octocat/gists{/gist_id} + starred_url: + type: string + example: https://api.github.com/users/octocat/starred{/owner}{/repo} + subscriptions_url: + type: string + format: uri + example: https://api.github.com/users/octocat/subscriptions + organizations_url: + type: string + format: uri + example: https://api.github.com/users/octocat/orgs + repos_url: + type: string + format: uri + example: https://api.github.com/users/octocat/repos + events_url: + type: string + example: https://api.github.com/users/octocat/events{/privacy} + received_events_url: + 
type: string + format: uri + example: https://api.github.com/users/octocat/received_events + type: + type: string + example: User + site_admin: + type: boolean + starred_at: + type: string + example: '"2020-07-09T00:17:55Z"' + required: + - avatar_url + - events_url + - followers_url + - following_url + - gists_url + - gravatar_id + - html_url + - id + - node_id + - login + - organizations_url + - received_events_url + - repos_url + - site_admin + - starred_url + - subscriptions_url + - type + - url + nullable: true + verification: + title: Verification + type: object + properties: + verified: + type: boolean + reason: + type: string + payload: + type: string + nullable: true + signature: + type: string + nullable: true + required: + - verified + - reason + - payload + - signature + diff-entry: + title: Diff Entry + description: Diff Entry + type: object + properties: + sha: + type: string + example: bbcd538c8e72b8c175046e27cc8f907076331401 + filename: + type: string + example: file1.txt + status: + type: string + enum: + - added + - removed + - modified + - renamed + - copied + - changed + - unchanged + example: added + additions: + type: integer + example: 103 + deletions: + type: integer + example: 21 + changes: + type: integer + example: 124 + blob_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/blob/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + raw_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/raw/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + contents_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/contents/file1.txt?ref=6dcb09b5b57875f334f61aebed695e2e4193db5e + patch: + type: string + example: '@@ -132,7 +132,7 @@ module Test @@ -1000,7 +1000,7 @@ module Test' + previous_filename: + type: string + example: file.txt + required: + - additions + - blob_url + - changes + - contents_url + - deletions + - filename + - 
raw_url + - sha + - status + commit: + title: Commit + description: Commit + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + sha: + type: string + example: 6dcb09b5b57875f334f61aebed695e2e4193db5e + node_id: + type: string + example: MDY6Q29tbWl0NmRjYjA5YjViNTc4NzVmMzM0ZjYxYWViZWQ2OTVlMmU0MTkzZGI1ZQ== + html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/6dcb09b5b57875f334f61aebed695e2e4193db5e + comments_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e/comments + commit: + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + author: + $ref: '#/components/schemas/nullable-git-user' + committer: + $ref: '#/components/schemas/nullable-git-user' + message: + type: string + example: Fix all the bugs + comment_count: + type: integer + example: 0 + tree: + type: object + properties: + sha: + type: string + example: 827efc6d56897b048c772eb4087f854f46256132 + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/tree/827efc6d56897b048c772eb4087f854f46256132 + required: + - sha + - url + verification: + $ref: '#/components/schemas/verification' + required: + - author + - committer + - comment_count + - message + - tree + - url + author: + $ref: '#/components/schemas/nullable-simple-user' + committer: + $ref: '#/components/schemas/nullable-simple-user' + parents: + type: array + items: + type: object + properties: + sha: + type: string + example: 7638417db6d59f3c431d3e1f261cc637155684cd + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/7638417db6d59f3c431d3e1f261cc637155684cd 
+ html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/7638417db6d59f3c431d3e1f261cc637155684cd + required: + - sha + - url + stats: + type: object + properties: + additions: + type: integer + deletions: + type: integer + total: + type: integer + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - url + - sha + - node_id + - html_url + - comments_url + - commit + - author + - committer + - parents + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/test_files/yaml/openapi_edge_cases.yml b/test/test_files/yaml/openapi_edge_cases.yml new file mode 100644 index 00000000..cef304c5 --- /dev/null +++ b/test/test_files/yaml/openapi_edge_cases.yml @@ -0,0 +1,13 @@ +openapi: 3.0.0 +info: + title: Edge Cases API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /missing-operation-id: + get: + summary: Missing operationId + responses: + '200': + description: OK diff --git a/test/test_files/yaml/openapi_error_handling.yml b/test/test_files/yaml/openapi_error_handling.yml new file mode 100644 index 00000000..5cf23fe5 --- /dev/null +++ b/test/test_files/yaml/openapi_error_handling.yml @@ -0,0 +1,24 @@ +openapi: 3.0.0 +info: + title: Error Handling API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /error/{status_code}: + get: + summary: Raise HTTP error + operationId: raiseHttpError + parameters: + - name: status_code + in: path + required: true + schema: + type: integer + responses: + '400': + description: Bad Request + '401': + description: Unauthorized + '404': + description: Not Found \ No newline at end of file diff --git a/test/test_files/yaml/openapi_greeting_service.yml b/test/test_files/yaml/openapi_greeting_service.yml new file mode 100644 index 00000000..701dee33 --- /dev/null +++ b/test/test_files/yaml/openapi_greeting_service.yml @@ -0,0 +1,272 @@ +openapi: 3.0.0 +info: + title: 
Greeting Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /greet/{name}: + post: + operationId: greet + parameters: + - name: name + in: path + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/MessageBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-params/{name}: + get: + operationId: greetParams + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-body: + post: + operationId: greetBody + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GreetBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-api-key/{name}: + get: + operationId: greetApiKey + security: + - ApiKeyAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-basic-auth/{name}: + get: + operationId: greetBasicAuth + security: + - BasicAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + 
/greet-api-key-query/{name}: + get: + operationId: greetApiKeyQuery + security: + - ApiKeyAuthQuery: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-api-key-cookie/{name}: + get: + operationId: greetApiKeyCookie + security: + - ApiKeyAuthCookie: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-bearer-auth/{name}: + get: + operationId: greetBearerAuth + security: + - BearerAuth: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-oauth/{name}: + get: + operationId: greetOAuth + security: + - OAuth2: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' +components: + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + BasicAuth: + type: http + scheme: basic + ApiKeyAuthQuery: + type: apiKey + in: query + name: 
api_key + ApiKeyAuthCookie: + type: apiKey + in: cookie + name: api_key + BearerAuth: + type: http + scheme: bearer + OAuth2: + type: oauth2 + flows: + authorizationCode: + authorizationUrl: https://example.com/oauth/authorize + tokenUrl: https://example.com/oauth/token + scopes: + read:greet: Read access to greeting service + + schemas: + GreetBody: + type: object + properties: + message: + type: string + name: + type: string + required: + - message + - name + + MessageBody: + type: object + properties: + message: + type: string + required: + - message + + GreetingResponse: + type: object + properties: + greeting: + type: string + + ErrorResponse: + type: object + properties: + detail: + type: string \ No newline at end of file diff --git a/test/test_files/yaml/openapi_order_service.yml b/test/test_files/yaml/openapi_order_service.yml new file mode 100644 index 00000000..07360ea5 --- /dev/null +++ b/test/test_files/yaml/openapi_order_service.yml @@ -0,0 +1,75 @@ +openapi: 3.0.0 +info: + title: Order Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /orders: + post: + summary: Create a new order + operationId: createOrder + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Order' + responses: + '201': + description: Created + content: + application/json: + schema: + $ref: '#/components/schemas/OrderResponse' + +components: + schemas: + Order: + type: object + properties: + customer: + $ref: '#/components/schemas/Customer' + items: + type: array + items: + $ref: '#/components/schemas/OrderItem' + required: + - customer + - items + + Customer: + type: object + properties: + name: + type: string + email: + type: string + required: + - name + - email + + OrderItem: + type: object + properties: + product: + type: string + quantity: + type: integer + required: + - product + - quantity + + OrderResponse: + type: object + properties: + orderId: + type: string + status: + type: string 
+ totalAmount: + type: number + required: + - orderId + - status + - totalAmount \ No newline at end of file diff --git a/test/test_files/yaml/serper.yml b/test/test_files/yaml/serper.yml new file mode 100644 index 00000000..9d5b1a85 --- /dev/null +++ b/test/test_files/yaml/serper.yml @@ -0,0 +1,39 @@ +openapi: 3.0.0 +info: + title: SerperDev + version: 1.0.0 + description: API for performing search queries +servers: + - url: https://google.serper.dev +paths: + /search: + post: + operationId: serperdev_search + description: Search the web with Google + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + q: + type: string + required: + - q + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + additionalProperties: true + security: + - apikey: [] +components: + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/util/__init__.py b/test/util/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/util/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0