Issue with test_module_method in attention_check.py #1

Open · KaelanDt opened this issue Dec 7, 2024 · 1 comment
KaelanDt commented Dec 7, 2024

I tried to run the following cell in the attention tutorial notebook:

from attention.control_values.attention_checks import test_module_method

mha = MultiHeadAttention(c_in, c, N_head, attn_dim=attn_dim, gated=True)

test_module_method(mha, 'mha_prep_qkv', ('q', 'k', 'v'), ('q_prep', 'k_prep', 'v_prep'), control_folder, mha.prepare_qkv)

and it throws the following error:

File ~/Desktop/code-exercises/alphafold/tutorials/attention/control_values/attention_checks.py:182, in test_module_method(module, test_name, input_names, output_names, control_folder, method, include_batched, overwrite_results)
    180     #print(out.shape, expected_out.shape, out, expected_out)
    181     print(torch.linalg.norm(out - expected_out))
--> 182     assert torch.allclose(out, expected_out), f'Problem with output {out_name} in test {test_name} in non-batched check.'
    184 if include_batched:
    185     for out, out_file_name, out_name in zip(batched_out, out_file_names, output_names):

AssertionError: Problem with output q_prep in test mha_prep_qkv in non-batched check.

When I print the norm of the difference between the two tensors in the torch.allclose call that raises the error, it is about 7e-6, which is larger than torch.allclose's default absolute tolerance (atol=1e-8), so I believe this can be fixed by simply passing an explicit tolerance. Happy to make a PR doing that if you agree this is the issue!
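
For illustration, here is a minimal sketch of the change I have in mind for that assert in attention_checks.py (the atol value is a guess on my part, not tuned):

# Sketch of the proposed fix (illustrative only): pass explicit tolerances so
# that small cross-platform floating-point differences don't fail the check.
# rtol=1e-5 is torch.allclose's default; atol=1e-5 is the loosened value.
assert torch.allclose(out, expected_out, rtol=1e-5, atol=1e-5), \
    f'Problem with output {out_name} in test {test_name} in non-batched check.'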

kilianmandon (Owner) commented Dec 10, 2024

Hi,

thanks for reporting the issue! I'm not very knowledgeable about cross-platform compatibility; do you think an error of this size is to be expected? It seems odd to me that floating-point computations should differ by that much, but I'm not sure how consistent PyTorch is across different platforms.

In any case, the error seems to be small enough not to lead to false positives, so I'd welcome a pull request!
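
For my own sanity check, a toy example of how the tolerances play out (the numbers are made up to mimic the ~7e-6 deviation you reported):

import torch

# A constant offset of 7e-6 fails torch.allclose's defaults
# (rtol=1e-5, atol=1e-8) for values near zero, but passes with atol=1e-5.
out = torch.zeros(3)
expected_out = out + 7e-6
print(torch.allclose(out, expected_out))             # False under the defaults
print(torch.allclose(out, expected_out, atol=1e-5))  # True with a relaxed atol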
