Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add ruff rule FBT003 (boolean-trap) #29424

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions libs/core/langchain_core/_api/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):


def _validate_deprecation_params(
pending: bool,
removal: str,
alternative: str,
alternative_import: str,
*,
pending: bool,
) -> None:
"""Validate the deprecation parameters."""
if pending and removal:
Expand Down Expand Up @@ -130,7 +131,9 @@ def deprecated(
def the_function_to_deprecate():
pass
"""
_validate_deprecation_params(pending, removal, alternative, alternative_import)
_validate_deprecation_params(
removal, alternative, alternative_import, pending=pending
)

def deprecate(
obj: T,
Expand Down
25 changes: 13 additions & 12 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def get_child(self, tag: Optional[str] = None) -> CallbackManager:
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager


Expand Down Expand Up @@ -641,7 +641,7 @@ def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager


Expand Down Expand Up @@ -1563,11 +1563,11 @@ def configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)


Expand Down Expand Up @@ -2087,11 +2087,11 @@ def configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)


Expand Down Expand Up @@ -2236,11 +2236,12 @@ def _configure(
callback_manager_cls: type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
*,
verbose: bool = False,
) -> T:
"""Configure the callback manager.

Expand Down Expand Up @@ -2314,13 +2315,13 @@ def _configure(
else (local_callbacks.handlers if local_callbacks else [])
)
for handler in local_handlers_:
callback_manager.add_handler(handler, False)
callback_manager.add_handler(handler, inherit=False)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
callback_manager.add_tags(local_tags or [], inherit=False)
if inheritable_metadata or local_metadata:
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
callback_manager.add_metadata(local_metadata or {}, inherit=False)
if tracing_metadata:
callback_manager.add_metadata(tracing_metadata.copy())
if tracing_tags:
Expand Down Expand Up @@ -2355,18 +2356,18 @@ def _configure(
if debug:
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
callback_manager.add_handler(StdOutCallbackHandler(), inherit=False)
if debug and not any(
isinstance(handler, ConsoleCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
callback_manager.add_handler(ConsoleCallbackHandler())
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
callback_manager.add_handler(tracer_v2)
else:
try:
handler = LangChainTracer(
Expand All @@ -2378,7 +2379,7 @@ def _configure(
),
tags=tracing_tags,
)
callback_manager.add_handler(handler, True)
callback_manager.add_handler(handler)
except Exception as e:
logger.warning(
"Unable to load requested LangChainTracer."
Expand Down
5 changes: 3 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,7 @@ async def _agenerate_helper(
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[AsyncCallbackManagerForLLMRun],
*,
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
Expand Down Expand Up @@ -1212,7 +1213,7 @@ async def agenerate(
prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
Expand All @@ -1235,7 +1236,7 @@ async def agenerate(
missing_prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
llm_output = await aupdate_cache(
Expand Down
8 changes: 5 additions & 3 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _get_filtered_args(


def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.

Expand Down Expand Up @@ -1073,7 +1073,7 @@ def get_all_basemodel_annotations(
generic_map = dict(zip(generic_type_vars, get_args(parent)))
for field in getattr(parent_origin, "__annotations__", {}):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound
annotations[field], generic_map, default_to_bound=default_to_bound
)

return {
Expand All @@ -1085,6 +1085,7 @@ def get_all_basemodel_annotations(
def _replace_type_vars(
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
*,
default_to_bound: bool = True,
) -> type:
generic_map = generic_map or {}
Expand All @@ -1097,7 +1098,8 @@ def _replace_type_vars(
return type_
elif (origin := get_origin(type_)) and (args := get_args(type_)):
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
_replace_type_vars(arg, generic_map, default_to_bound=default_to_bound)
for arg in args
)
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/tracers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _get_trace_callbacks(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
callback_manager.add_handler(tracer)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
Expand Down Expand Up @@ -217,4 +217,4 @@ def register_configure_hook(
)


register_configure_hook(run_collector_var, False)
register_configure_hook(run_collector_var, inheritable=False)
1 change: 1 addition & 0 deletions libs/core/langchain_core/utils/mustache.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _html_escape(string: str) -> str:
def _get_key(
key: str,
scopes: Scopes,
*,
warn: bool,
keep: bool,
def_ldel: str,
Expand Down
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FBT003", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CustomHandler(AsyncCallbackHandler):
called.
"""

def __init__(self, run_inline: bool) -> None:
def __init__(self, *, run_inline: bool) -> None:
"""Initialize the handler."""
self.run_inline = run_inline

Expand Down Expand Up @@ -91,7 +91,7 @@ async def set_counter_var() -> Any:
counter_var.reset(token)

class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
def __init__(self, name: str, run_inline: bool = True):
def __init__(self, name: str, *, run_inline: bool = True):
self.name = name
self.run_inline = run_inline

Expand Down
22 changes: 11 additions & 11 deletions libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@
]


def _get_iter(use_tool_calls: bool = False) -> Any:
def _get_iter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
Expand All @@ -374,7 +374,7 @@ def input_iter(_: Any) -> Iterator[BaseMessage]:
return input_iter


def _get_aiter(use_tool_calls: bool = False) -> Any:
def _get_aiter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
Expand All @@ -389,7 +389,7 @@ async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:

def test_partial_json_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()

actual = list(chain.stream(None))
Expand All @@ -402,7 +402,7 @@ def test_partial_json_output_parser() -> None:

async def test_partial_json_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()

actual = [p async for p in chain.astream(None)]
Expand All @@ -415,7 +415,7 @@ async def test_partial_json_output_parser_async() -> None:

def test_partial_json_output_parser_return_id() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser(return_id=True)

actual = list(chain.stream(None))
Expand All @@ -434,7 +434,7 @@ def test_partial_json_output_parser_return_id() -> None:

def test_partial_json_output_key_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = list(chain.stream(None))
Expand All @@ -444,7 +444,7 @@ def test_partial_json_output_key_parser() -> None:

async def test_partial_json_output_parser_key_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

Expand All @@ -455,7 +455,7 @@ async def test_partial_json_output_parser_key_async() -> None:

def test_partial_json_output_key_parser_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
Expand All @@ -466,7 +466,7 @@ def test_partial_json_output_key_parser_first_only() -> None:

async def test_partial_json_output_parser_key_async_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
Expand Down Expand Up @@ -507,7 +507,7 @@ class NameCollector(BaseModel):

def test_partial_pydantic_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
Expand All @@ -519,7 +519,7 @@ def test_partial_pydantic_output_parser() -> None:

async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def _as_async_iterator(iterable: list) -> AsyncIterator:


async def _collect_events(
events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True
events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
) -> list[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]
Expand Down
Loading