Skip to content

Commit

Permalink
Suggest name from scope, if one exists (#751)
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored Mar 11, 2024
1 parent f107b01 commit a0fdbcf
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,7 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
):
prepared = prepare_type(result.return_value)
if should_suggest_type(prepared):
detail, metadata = display_suggested_type(prepared)
detail, metadata = display_suggested_type(prepared, self.scopes)
self._show_error_if_checking(
node,
error_code=ErrorCode.suggested_return_type,
Expand Down Expand Up @@ -2183,7 +2183,7 @@ def _get_potential_function(self, node: FunctionDefNode) -> Optional[object]:
sig = self.signature_from_value(KnownValue(potential_function))
if isinstance(sig, Signature):
self.checker.callable_tracker.record_callable(
node, potential_function, sig, self
node, potential_function, sig, scopes=self.scopes, ctx=self
)
return potential_function

Expand Down
38 changes: 26 additions & 12 deletions pyanalyze/suggested_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .node_visitor import ErrorContext, Failure
from .safe import safe_getattr, safe_isinstance
from .signature import Signature
from .stacked_scopes import StackedScopes, VisitorState
from .value import (
NO_RETURN_VALUE,
AnnotatedValue,
Expand Down Expand Up @@ -44,6 +45,7 @@ class CallableData:
node: FunctionNode
ctx: ErrorContext
sig: Signature
scopes: StackedScopes
calls: List[CallArgs] = field(default_factory=list)

def check(self) -> Iterator[Failure]:
Expand All @@ -65,7 +67,7 @@ def check(self) -> Iterator[Failure]:
suggested = unite_values(*all_values)
if not should_suggest_type(suggested):
continue
detail, metadata = display_suggested_type(suggested)
detail, metadata = display_suggested_type(suggested, self.scopes)
failure = self.ctx.show_error(
param,
f"Suggested type for parameter {param.arg}",
Expand All @@ -89,10 +91,15 @@ class CallableTracker:
)

def record_callable(
self, node: FunctionNode, callable: object, sig: Signature, ctx: ErrorContext
self,
node: FunctionNode,
callable: object,
sig: Signature,
scopes: StackedScopes,
ctx: ErrorContext,
) -> None:
"""Record when we encounter a callable."""
self.callable_to_data[callable] = CallableData(node, ctx, sig)
self.callable_to_data[callable] = CallableData(node, ctx, sig, scopes)

def record_call(self, callable: object, arguments: Mapping[str, Value]) -> None:
"""Record the actual arguments passed in in a call."""
Expand All @@ -108,7 +115,9 @@ def check(self) -> List[Failure]:
return failures


def display_suggested_type(value: Value) -> Tuple[str, Optional[Dict[str, Any]]]:
def display_suggested_type(
value: Value, scopes: StackedScopes
) -> Tuple[str, Optional[Dict[str, Any]]]:
value = prepare_type(value)
if isinstance(value, MultiValuedValue) and value.vals:
cae = CanAssignError("Union", [CanAssignError(str(val)) for val in value.vals])
Expand All @@ -122,14 +131,19 @@ def display_suggested_type(value: Value) -> Tuple[str, Optional[Dict[str, Any]]]
# exist, and we should be using a Callable type instead anyway.
metadata = None
else:
suggested_type = stringify_object(value.typ)
imports = []
if isinstance(value.typ, str):
if "." in value.typ:
imports.append(value.typ)
elif safe_getattr(value.typ, "__module__", None) != "builtins":
imports.append(suggested_type.split(".")[0])
metadata = {"suggested_type": suggested_type, "imports": imports}
typ_str = stringify_object(value.typ)
typ_name = typ_str.split(".")[-1]
scope_value = scopes.get(typ_name, None, VisitorState.check_names)
if isinstance(scope_value, KnownValue) and scope_value.val is value.typ:
metadata = {"suggested_type": typ_name, "imports": []}
else:
imports = []
if isinstance(value.typ, str):
if "." in value.typ:
imports.append(value.typ)
elif safe_getattr(value.typ, "__module__", None) != "builtins":
imports.append(typ_str.split(".")[0])
metadata = {"suggested_type": typ_str, "imports": imports}
else:
metadata = None
return str(cae), metadata
Expand Down

0 comments on commit a0fdbcf

Please sign in to comment.