
thunder may treat global (maybe nonlocal) value as constant in computation trace without a check in prologue #1464

Open
kshitij12345 opened this issue Nov 22, 2024 · 1 comment
Labels: bug, jit

@kshitij12345 (Collaborator):

import torch
import thunder
from contextvars import ContextVar

_compile_data = ContextVar("compile_data", default=1)

def fn(x):
    # v is read from context/global state, not from fn's arguments
    v = _compile_data.get()
    return x + v

jfn = thunder.jit(fn)
o = jfn(torch.ones(3,))
print(o)  # tensor([2., 2., 2.])

# Update the ContextVar; the cached trace is reused, so the new value is ignored.
_compile_data.set((2,))
o = jfn(torch.ones(3,))
print(o)  # tensor([2., 2., 2.]) (should be tensor([3., 3., 3.]))

print(thunder.last_prologue_traces(jfn)[-1])
print(thunder.last_traces(jfn)[-1])

Prologue Trace

@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
  # args: "Any"
  check_len(args, 1)
    # prims.check_len(args, 1)
  # kwargs: "Any"
  check_len(kwargs, 0)
    # prims.check_len(kwargs, 0)
  x: "cpu f32[3]" = args[0]
  check_tensor_metadata(x, (3,), 'cpu', torch.float32, False)
    # prims.check_tensor_shape_and_metadata(x, (3,), 'cpu', torch.float32, False)
  cache_info: "Any" = thunder._get_cache_info()
  cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
  check_literal_like(cache_info_default_dtype, torch.float32)
    # prims.check_literal_like(cache_info_default_dtype, torch.float32)
  cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
  check_literal_like(cache_info_default_device, torch.device("cpu"))
    # prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
  cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
  check_number_type_and_value(cache_info_is_autocast_enabled, False)
    # prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
  cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
  check_number_type_and_value(cache_info_no_grad_sync, False)
    # prims.check_number_type_and_value(cache_info_no_grad_sync, False)
  cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
  check_string_value(cache_info_alias_tensor_indices, '')
    # prims.check_string_value(cache_info_alias_tensor_indices, '')
  cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
  check_number_type_and_value(cache_info_is_grad_enabled, True)
    # prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
  return ((x,), ())
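
The prologue above guards the tensor argument and the autocast/grad cache info, but it never checks the value read from _compile_data, so the cached trace is always reused. A hypothetical guard (illustrative only, not actual thunder output) could reuse the same check_number_type_and_value primitive that is already emitted for the cache-info entries:

  compile_data_value = _compile_data.get()  # hypothetical: fetch the ContextVar value in the prologue
  check_number_type_and_value(compile_data_value, 1)
    # prims.check_number_type_and_value(compile_data_value, 1)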

Computation Trace

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[3]"
  t0 = torch.add(x, 1, alpha=1)  # t0: "cpu f32[3]"
    # t0 = ltorch.add(x, 1, alpha=1)  # t0: "cpu f32[3]"
      # _ = prims.convert_element_type(1, float)
      # t0 = prims.add(x, 1.0)  # t0: "cpu f32[3]"
  return t0
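
The computation trace shows the fetched value baked in as the literal 1 (torch.add(x, 1, alpha=1)), leaving nothing for a later call to vary. A minimal workaround sketch, assuming thunder's default caching guards plain number arguments with the same check_number_type_and_value seen above, is to pass the context value in as an argument so the prologue can specialize on it and retrace when it changes:

import torch
import thunder
from contextvars import ContextVar

_compile_data = ContextVar("compile_data", default=1)

def fn(x, v):
    # v is an explicit argument, so the prologue can guard on its value
    return x + v

jfn = thunder.jit(fn)
print(jfn(torch.ones(3,), _compile_data.get()))  # tensor([2., 2., 2.])

_compile_data.set(2)
print(jfn(torch.ones(3,), _compile_data.get()))  # expected: tensor([3., 3., 3.]) after a retrace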
@kshitij12345 (Collaborator, Author):

Relevant conversation: #1458 (comment)
