Skip to content

Commit

Permalink
[jit] support unpacking hashable object for indexing (#1290)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 authored Oct 11, 2024
1 parent cc7335c commit c9fe11b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ def from_binary_subscr(provenance, *, new_output=False):
output = Proxy("subscr") # name? collectify?
else:
output = p
if isinstance(idx, (int, str)):
if isinstance(idx, (int, str, Proxy)):
if isinstance(idx, int):
idx = int(idx)
elif isinstance(idx, str):
Expand Down
29 changes: 29 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2937,3 +2937,32 @@ def fn(x):
return_bsym = trace.bound_symbols[-1]
assert return_bsym.sym.id == thunder.prims.PrimIDs.RETURN
assert return_bsym.output is None


def test_indexing_with_hashable_object():
class HashableClass:
def __hash__(self):
return id(self)

h = HashableClass()
d = {h: 1, 1: 0}

def fn():
return d[h]

jfn = thunder.jit(fn)
assert jfn() == 1
assert thunder.cache_misses(jfn) == 1 # Due to first compilation.

# Call jfn with no changes
# this should be cache hit.
assert jfn() == 1
assert thunder.cache_hits(jfn) == 1
assert thunder.cache_misses(jfn) == 1

# Change the value of the captured dict.
# This should be a cache miss, verify that.
d[h] = 2
assert jfn() == 2 # Verify that jfn now returns 2
assert thunder.cache_hits(jfn) == 1
assert thunder.cache_misses(jfn) == 2

0 comments on commit c9fe11b

Please sign in to comment.