From 75122646b556fd62abaef57bfcd26f11638f88c3 Mon Sep 17 00:00:00 2001 From: Naman Jain Date: Tue, 12 Mar 2024 11:12:45 +0530 Subject: [PATCH] core[patch]: fixed circular dependency with json schema (#18657) **Description:** Circular dependencies when parsing references leading to `RecursionError: maximum recursion depth exceeded` issue. This PR address the issue by handling previously seen refs as in any typical DFS to avoid infinite depths. **Issue:** https://github.com/langchain-ai/langchain/issues/12163 **Twitter handle:** https://twitter.com/theBhulawat - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Bagatur Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- libs/core/langchain_core/utils/json_schema.py | 44 ++++++++++++---- .../unit_tests/utils/test_json_schema.py | 51 +++++++++++++++++++ 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index 95313b3f95493..aecdae470d849 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Set def _retrieve_ref(path: str, schema: dict) -> dict: @@ -21,40 +21,66 @@ def _retrieve_ref(path: str, schema: dict) -> dict: def _dereference_refs_helper( - obj: Any, full_schema: dict, skip_keys: Sequence[str] + obj: Any, + full_schema: Dict[str, Any], + skip_keys: Sequence[str], + processed_refs: Optional[Set[str]] = None, ) -> Any: + if processed_refs is None: + processed_refs = set() + if isinstance(obj, dict): obj_out = {} for k, v in obj.items(): if k in skip_keys: obj_out[k] = v elif k == "$ref": + if v in processed_refs: + continue + processed_refs.add(v) ref = _retrieve_ref(v, full_schema) - return _dereference_refs_helper(ref, full_schema, skip_keys) + full_ref = _dereference_refs_helper( + ref, full_schema, skip_keys, processed_refs + ) + processed_refs.remove(v) + return full_ref elif isinstance(v, (list, dict)): - obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys) + obj_out[k] = _dereference_refs_helper( + v, full_schema, skip_keys, processed_refs + ) else: obj_out[k] = v return obj_out elif isinstance(obj, list): - return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj] + return [ + _dereference_refs_helper(el, full_schema, skip_keys, processed_refs) + for el in obj + ] else: return obj -def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]: +def _infer_skip_keys( + obj: Any, full_schema: dict, processed_refs: Optional[Set[str]] = None +) -> List[str]: + if processed_refs is None: + processed_refs = set() + keys = [] if isinstance(obj, dict): for k, v in obj.items(): if k == "$ref": + if v in processed_refs: + continue + processed_refs.add(v) ref = _retrieve_ref(v, full_schema) keys.append(v.split("/")[1]) - keys += _infer_skip_keys(ref, full_schema) + keys += _infer_skip_keys(ref, full_schema, processed_refs) elif isinstance(v, (list, dict)): - keys += _infer_skip_keys(v, full_schema) + keys += _infer_skip_keys(v, full_schema, processed_refs) elif isinstance(obj, list): for el in obj: - keys += _infer_skip_keys(el, full_schema) + keys += _infer_skip_keys(el, full_schema, processed_refs) return keys diff --git a/libs/core/tests/unit_tests/utils/test_json_schema.py b/libs/core/tests/unit_tests/utils/test_json_schema.py index 6f0a00fd99d2f..44ad61609b5a5 100644 --- a/libs/core/tests/unit_tests/utils/test_json_schema.py +++ b/libs/core/tests/unit_tests/utils/test_json_schema.py @@ -181,3 +181,54 @@ def test_dereference_refs_integer_ref() -> None: } actual = dereference_refs(schema) assert actual == expected + + +def test_dereference_refs_cyclical_refs() -> None: + schema = { + "type": "object", + "properties": { + "user": {"$ref": "#/$defs/user"}, + "customer": {"$ref": "#/$defs/user"}, + }, + "$defs": { + "user": { + "type": "object", + "properties": { + "friends": {"type": "array", "items": {"$ref": "#/$defs/user"}} + }, + } + }, + } + expected = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "friends": { + "type": "array", + "items": {}, # Recursion is broken here + } + }, + }, + "customer": { + "type": "object", + "properties": { + "friends": { + "type": "array", + "items": {}, # Recursion is broken here + } + }, + }, + }, + "$defs": { + "user": { + "type": "object", + "properties": { + "friends": {"type": "array", "items": {"$ref": "#/$defs/user"}} + }, + } + }, + } + actual = dereference_refs(schema) + assert actual == expected