diff --git a/reflex/vars/object.py b/reflex/vars/object.py index cb29cabfb8..89479bbc41 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -22,7 +22,12 @@ from reflex.utils import types from reflex.utils.exceptions import VarAttributeError -from reflex.utils.types import GenericType, get_attribute_access_type, get_origin +from reflex.utils.types import ( + GenericType, + get_attribute_access_type, + get_origin, + safe_issubclass, +) from .base import ( CachedVarOperation, @@ -187,10 +192,14 @@ def __getitem__(self, key: Var | Any) -> Var: Returns: The item from the object. """ + from .sequence import LiteralStringVar + if not isinstance(key, (StringVar, str, int, NumberVar)) or ( isinstance(key, NumberVar) and key._is_strict_float() ): raise_unsupported_operand_types("[]", (type(self), type(key))) + if isinstance(key, str) and isinstance(Var.create(key), LiteralStringVar): + return self.__getattr__(key) return ObjectItemOperation.create(self, key).guess_type() # NoReturn is used here to catch when key value is Any @@ -260,12 +269,12 @@ def __getattr__(self, name: str) -> Var: if types.is_optional(var_type): var_type = get_args(var_type)[0] - fixed_type = var_type if isclass(var_type) else get_origin(var_type) + fixed_type = get_origin(var_type) or var_type if ( - (isclass(fixed_type) and not issubclass(fixed_type, Mapping)) + is_typeddict(fixed_type) + or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping)) or (fixed_type in types.UnionTypes) - or is_typeddict(fixed_type) ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index a5a74c9ee0..16885cd062 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -10,6 +10,8 @@ def VarOperations(): """App with var operations.""" + from typing import TypedDict + import reflex as rx from reflex.vars.base import LiteralVar from reflex.vars.sequence import ArrayVar @@ -17,6 +19,10 @@ def VarOperations(): class Object(rx.Base): name: str = "hello" + class Person(TypedDict): + name: str + age: int + class VarOperationState(rx.State): int_var1: rx.Field[int] = rx.field(10) int_var2: rx.Field[int] = rx.field(5) @@ -34,6 +40,9 @@ class VarOperationState(rx.State): dict1: rx.Field[dict[int, int]] = rx.field({1: 2}) dict2: rx.Field[dict[int, int]] = rx.field({3: 4}) html_str: rx.Field[str] = rx.field("