Skip to content

Commit

Permalink
use getattr when given str in getitem (#4761)
Browse files Browse the repository at this point in the history
* use getattr when given str in getitem

* stronger checking and tests

* switch ordering

* use safe issubclass

* calculate origin differently
  • Loading branch information
adhami3310 authored Feb 6, 2025
1 parent 6f4d328 commit 1651289
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
17 changes: 13 additions & 4 deletions reflex/vars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@

from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
from reflex.utils.types import (
GenericType,
get_attribute_access_type,
get_origin,
safe_issubclass,
)

from .base import (
CachedVarOperation,
Expand Down Expand Up @@ -187,10 +192,14 @@ def __getitem__(self, key: Var | Any) -> Var:
Returns:
The item from the object.
"""
from .sequence import LiteralStringVar

if not isinstance(key, (StringVar, str, int, NumberVar)) or (
isinstance(key, NumberVar) and key._is_strict_float()
):
raise_unsupported_operand_types("[]", (type(self), type(key)))
if isinstance(key, str) and isinstance(Var.create(key), LiteralStringVar):
return self.__getattr__(key)
return ObjectItemOperation.create(self, key).guess_type()

# NoReturn is used here to catch when key value is Any
Expand Down Expand Up @@ -260,12 +269,12 @@ def __getattr__(self, name: str) -> Var:
if types.is_optional(var_type):
var_type = get_args(var_type)[0]

fixed_type = var_type if isclass(var_type) else get_origin(var_type)
fixed_type = get_origin(var_type) or var_type

if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping))
is_typeddict(fixed_type)
or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_var_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@

def VarOperations():
"""App with var operations."""
from typing import TypedDict

import reflex as rx
from reflex.vars.base import LiteralVar
from reflex.vars.sequence import ArrayVar

class Object(rx.Base):
name: str = "hello"

class Person(TypedDict):
name: str
age: int

class VarOperationState(rx.State):
int_var1: rx.Field[int] = rx.field(10)
int_var2: rx.Field[int] = rx.field(5)
Expand All @@ -34,6 +40,9 @@ class VarOperationState(rx.State):
dict1: rx.Field[dict[int, int]] = rx.field({1: 2})
dict2: rx.Field[dict[int, int]] = rx.field({3: 4})
html_str: rx.Field[str] = rx.field("<div>hello</div>")
people: rx.Field[list[Person]] = rx.field(
[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
)

app = rx.App(_state=rx.State)

Expand Down Expand Up @@ -619,6 +628,15 @@ def index():
),
id="dict_in_foreach3",
),
rx.box(
rx.foreach(
VarOperationState.people,
lambda person: rx.text.span(
"Hello " + person["name"], person["age"] + 3
),
),
id="typed_dict_in_foreach",
),
)


Expand Down Expand Up @@ -826,6 +844,7 @@ def test_var_operations(driver, var_operations: AppHarness):
("dict_in_foreach1", "a1b2"),
("dict_in_foreach2", "12"),
("dict_in_foreach3", "1234"),
("typed_dict_in_foreach", "Hello Alice33Hello Bob28"),
]

for tag, expected in tests:
Expand Down

0 comments on commit 1651289

Please sign in to comment.