diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 9b7fa1e894..5a6efa434e 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1384,7 +1384,7 @@ def from_binary_subscr(provenance, *, new_output=False): output = Proxy("subscr") # name? collectify? else: output = p - if isinstance(idx, (int, str)): + if isinstance(idx, (int, str, Proxy)): if isinstance(idx, int): idx = int(idx) elif isinstance(idx, str): diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 2e10944f16..345eabd921 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2937,3 +2937,32 @@ def fn(x): return_bsym = trace.bound_symbols[-1] assert return_bsym.sym.id == thunder.prims.PrimIDs.RETURN assert return_bsym.output is None + + +def test_indexing_with_hashable_object(): + class HashableClass: + def __hash__(self): + return id(self) + + h = HashableClass() + d = {h: 1, 1: 0} + + def fn(): + return d[h] + + jfn = thunder.jit(fn) + assert jfn() == 1 + assert thunder.cache_misses(jfn) == 1 # Due to first compilation. + + # Call jfn with no changes + # this should be cache hit. + assert jfn() == 1 + assert thunder.cache_hits(jfn) == 1 + assert thunder.cache_misses(jfn) == 1 + + # Change the value of the captured dict. + # This should be a cache miss, verify that. + d[h] = 2 + assert jfn() == 2 # Verify that jfn now returns 2 + assert thunder.cache_hits(jfn) == 1 + assert thunder.cache_misses(jfn) == 2