NNsight should fail when an Unset proxy from a previous trace / future computation is used in a patching experiment #143

Open
Butanium opened this issue May 25, 2024 · 0 comments

This kind of silent failure can make nnsight very hard to debug:

import torch as th
from nnsight import LanguageModel
nn_model = LanguageModel("gpt2", device_map="cpu")

# The patching fails silently because hidden was never saved, so it has no value in this trace
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits = nn_model.lm_head.output.save()

# The patching works because hidden was saved in the previous trace
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output.save()
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = hidden
    corrupted_logits2 = nn_model.lm_head.output.save()

# The patching fails silently because h[10].output has not been computed yet when h[0] runs
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = nn_model.transformer.h[10].output
    corrupted_logits3 = nn_model.lm_head.output.save()

with nn_model.trace("b"):
    clean_logits = nn_model.lm_head.output.save()
assert not th.allclose(clean_logits, corrupted_logits2), "this assert passes"
assert not th.allclose(clean_logits, corrupted_logits3), "this assert fails"
assert not th.allclose(clean_logits, corrupted_logits), "this assert fails"
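
Until nnsight raises on its own in these cases, a small guard can at least turn the silent no-op into a loud error. This is only a sketch: require_saved is a hypothetical helper, and it assumes a proxy exposes a .value attribute that raises when the underlying value was never computed or saved.

def require_saved(proxy):
    # Hypothetical guard: force the proxy to resolve before reusing it in a new trace.
    # Assumes accessing .value raises when the value was never set (e.g. .save() was
    # forgotten, or the module it comes from has not run yet).
    try:
        return proxy.value
    except Exception as exc:
        raise RuntimeError(
            "Proxy holds no value; did you forget .save() in the previous trace, "
            "or reference a module that has not run yet?"
        ) from exc

# With the guard, the first example above errors out instead of patching nothing:
with nn_model.trace("a"):
    hidden = nn_model.transformer.h[0].output  # .save() forgotten
with nn_model.trace("b"):
    nn_model.transformer.h[0].output = require_saved(hidden)  # raises RuntimeError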