
Ewald calculator is extremely slow due to for loop #80

Closed
sirmarcel opened this issue Oct 15, 2024 · 12 comments · Fixed by #84
Comments

@sirmarcel
Contributor

sirmarcel commented Oct 15, 2024

Heya, I noticed an errant for loop in the Ewald calculator. Here:

energy = torch.zeros_like(charges)
for i in range(num_atoms):
    energy[i] += torch.sum(
        G * cos_all[:, i] * cos_summed, dim=sum_idx - 1
    ) + torch.sum(G * sin_all[:, i] * sin_summed, dim=sum_idx - 1)

The fix is simple: replace the above with

def potential_function(c, s):
    return torch.sum(G * c * cos_summed + G * s * sin_summed, dim=sum_idx - 1)

energy = torch.vmap(potential_function, in_dims=(-1, -1))(cos_all, sin_all)
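As a self-contained illustration (toy shapes and random stand-in tensors, not the actual torch-pme internals), here is a minimal sketch showing that the per-atom loop and the `torch.vmap` version compute the same result:

```python
import torch

# Hypothetical stand-ins, loosely modeled on the Ewald k-space sums:
# num_k k-vectors, num_atoms atoms, all tensors random.
num_k, num_atoms = 64, 8
G = torch.rand(num_k)
cos_all = torch.rand(num_k, num_atoms)
sin_all = torch.rand(num_k, num_atoms)
cos_summed = torch.rand(num_k)
sin_summed = torch.rand(num_k)

# Loop version: one reduction over k per atom.
loop_energy = torch.zeros(num_atoms)
for i in range(num_atoms):
    loop_energy[i] = torch.sum(
        G * cos_all[:, i] * cos_summed + G * sin_all[:, i] * sin_summed
    )

# Vectorized version: vmap maps over the atom dimension (dim -1).
def potential_function(c, s):
    return torch.sum(G * c * cos_summed + G * s * sin_summed)

vmap_energy = torch.vmap(potential_function, in_dims=(-1, -1))(cos_all, sin_all)

assert torch.allclose(loop_energy, vmap_energy)
```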

This gives a speedup of about 10x to 1000x and passes all the correctness tests. In fact, with this fix Ewald outperforms PME in all my tests (EDIT: not quite).

Unfortunately, it doesn't work with TorchScript:


torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
  File "/work/cosmo-erc/langer/kuma/libs/torch-pme/src/torchpme/calculators/ewald.py", line 125
        #     ) + torch.sum(G * sin_all[:, i] * sin_summed, dim=sum_idx - 1)
    
        def potential_function(c, s):
        ~~~ <--- HERE
            return torch.sum(G * c * cos_summed + G * s * sin_summed, dim=sum_idx - 1)
    

Anyone want to get a 1000x speedup to their name?

@E-Rum
Contributor

E-Rum commented Oct 15, 2024

What is the maximum number of atoms in the system you used for testing?

@PicoCentauri
Contributor

1024

@sirmarcel
Contributor Author

sirmarcel commented Oct 15, 2024

Actually, I didn't scroll all the way down: there are some cases, surprisingly on the small side, where PME edges out a win. I'll run bigger cases soon.

For posterity, here's the benchmark script output for this change (on an L40S) (obviously, I disabled TorchScript):

results:
  1024_light_cpu:
    E: 0.028006192687826115
    PME: 0.05425682137502008
  1024_light_cuda:
    E: 0.0021106005006004125
    PME: 0.004031811624372494
  1024_tight_cpu:
    E: 0.24040848081131116
    PME: 0.3489146663760039
  1024_tight_cuda:
    E: 0.003573990061340737
    PME: 0.006905024374646018
  128_light_cpu:
    E: 0.002079558124023606
    PME: 0.046015231499040965
  128_light_cuda:
    E: 0.001985459062780137
    PME: 0.004574535623760312
  128_tight_cpu:
    E: 0.014728610187376034
    PME: 0.33456345718695957
  128_tight_cuda:
    E: 0.0019395997496758355
    PME: 0.0066749854377121665
system:
  cpu: x86_64
  gpu: NVIDIA L40S
  node: kl003
  platform: Linux-5.14.0-70.30.1.el9_0.x86_64-x86_64-with-glibc2.34
timestamp: '2024-10-15T15:24:30.801431'
version:
  torch: 2.4.1+cu121
  torch-pme-commit: not found
  torch-pme-status: not found

versus

results:
  1024_light_cpu:
    E: 0.5565416831886978
    E_MT: 0.5542255340005795
    PME: 0.05269402262638323
    PME_MT: 0.058186623187793884
  1024_light_cuda:
    E: 0.21703302312380401
    E_MT: 0.21828696675038373
    PME: 0.0032142497493623523
    PME_MT: 0.003692456311910064
  1024_tight_cpu:
    E: 15.574867601000733
    E_MT: 11.684244794687402
    PME: 0.32599988918809686
    PME_MT: 0.39711523550067795
  1024_tight_cuda:
    E: 0.5313887392512697
    E_MT: 0.531285271186789
    PME: 0.006680479064016254
    PME_MT: 0.006621209626246127
  128_light_cpu:
    E: 0.022671775999697275
    E_MT: 0.02297365800041007
    PME: 0.04760915656152065
    PME_MT: 0.048393513749033445
  128_light_cuda:
    E: 0.028179472687043017
    E_MT: 0.02838066218646418
    PME: 0.003127433312329231
    PME_MT: 0.003241197062379797
  128_tight_cpu:
    E: 0.0762400030016579
    E_MT: 0.07706811800017022
    PME: 0.3605894702486694
    PME_MT: 0.36393232043701573
  128_tight_cuda:
    E: 0.027397436311730416
    E_MT: 0.027304522625854588
    PME: 0.005997346625008504
    PME_MT: 0.006475141062765033
system:
  cpu: x86_64
  gpu: NVIDIA L40S
  node: kl001
  platform: Linux-5.14.0-70.30.1.el9_0.x86_64-x86_64-with-glibc2.34
timestamp: '2024-10-15T15:08:35.999047'
version:
  torch: 2.4.1+cu121
  torch-pme-commit: not found
  torch-pme-status: not found

@sirmarcel
Contributor Author

For the script, see #81.

@kvhuguenin
Contributor

Are the numbers the total time (per frame) or time per atom (per frame)? I would have guessed the former based on the script, but I ask because for PME the cost seems to be the same for N=128 and N=1024. Do we have such a huge overhead?

@sirmarcel
Contributor Author

@kvhuguenin Let's move this discussion to #81 :)

@sirmarcel
Contributor Author

sirmarcel commented Oct 15, 2024

Here's up to ~11k atoms:

Running 128_tight_cuda...
{'PME': 0.007009895374721964, 'E': 0.0021171706248424016}

Running 128_light_cuda...
{'PME': 0.004027711438538972, 'E': 0.0019485699376673438}

Running 1024_tight_cuda...
{'PME': 0.00684068312511954, 'E': 0.0036730438750964822}

Running 1024_light_cuda...
{'PME': 0.00403638187526667, 'E': 0.002026798438237165}

Running 3456_tight_cuda...
{'PME': 0.010416301751320134, 'E': 0.01812379137481912}

Running 3456_light_cuda...
{'PME': 0.004015727250589407, 'E': 0.0020725228750961833}

Running 8192_tight_cuda...
{'PME': 0.03968679750141746, 'E': 0.06595321762506501}

Running 8192_light_cuda...
{'PME': 0.005954447125986917, 'E': 0.006124208999608527}

Running 11664_tight_cuda...
{'PME': 0.07403455074927479, 'E': 0.1142072080001526}

Running 11664_light_cuda...
{'PME': 0.00985147124993091, 'E': 0.011776811812524102}

... and the old version, until I got bored with waiting:

Running 128_tight_cuda...
{'PME': 0.0065868528126884485, 'E': 0.03286850287440757}

Running 128_light_cuda...
{'PME': 0.003949769563405425, 'E': 0.032670553626303445}

Running 1024_tight_cuda...
{'PME': 0.006743700874721981, 'E': 0.5455628992494894}

Running 1024_light_cuda...
{'PME': 0.004013268562630401, 'E': 0.2591304460001993}

@ceriottm
Contributor

Can you try this?

        # Actual computation of trigonometric factors
        cos_all = torch.cos(trig_args)
        sin_all = torch.sin(trig_args)
        cos_summed_G = torch.sum(cos_all * charges_reshaped, dim=sum_idx) * G
        sin_summed_G = torch.sum(sin_all * charges_reshaped, dim=sum_idx) * G

        energy = (cos_summed_G @ cos_all + sin_summed_G @ sin_all).T
        energy /= torch.abs(cell.det())

This should compute the same stuff without vmap and should be torchscriptable
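As a quick sanity sketch (toy shapes, a single charge channel, and hypothetical stand-ins for `trig_args`, `charges_reshaped`, and `G`; the real torch-pme code carries extra dimensions), the matmul formulation reproduces the per-atom sums:

```python
import torch

# Random stand-in data: num_k k-vectors, num_atoms atoms.
num_k, num_atoms = 64, 8
trig_args = torch.rand(num_k, num_atoms)
charges = torch.rand(num_atoms)  # single channel, stand-in for charges_reshaped
G = torch.rand(num_k)

cos_all = torch.cos(trig_args)
sin_all = torch.sin(trig_args)
cos_summed_G = torch.sum(cos_all * charges, dim=-1) * G
sin_summed_G = torch.sum(sin_all * charges, dim=-1) * G

# Matmul version: a (num_k,) @ (num_k, num_atoms) product per term.
energy = cos_summed_G @ cos_all + sin_summed_G @ sin_all

# Reference: explicit per-atom reduction over k.
ref = torch.stack([
    torch.sum(
        G * cos_all[:, i] * (cos_all * charges).sum(dim=-1)
        + G * sin_all[:, i] * (sin_all * charges).sum(dim=-1)
    )
    for i in range(num_atoms)
])

assert torch.allclose(energy, ref)
```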

@ceriottm
Contributor

10% faster, couldn't resist

        sincos_all = torch.stack([torch.cos(trig_args), torch.sin(trig_args)])
        sincos_summed_G = torch.sum(sincos_all * charges_reshaped, dim=sum_idx) * G
        energy = sincos_summed_G.reshape(1, -1) @ sincos_all.reshape(-1, sincos_all.shape[-1])
        energy = energy.T / torch.abs(cell.det())

@sirmarcel
Contributor Author

10% faster, couldn't resist

        sincos_all = torch.stack([torch.cos(trig_args), torch.sin(trig_args)])
        sincos_summed_G = torch.sum(sincos_all * charges_reshaped, dim=sum_idx)*G
        energy = sincos_summed_G.reshape(1,-1) @ sincos_all.reshape(-1, sincos_all.shape[-1])
        energy = energy.T / torch.abs(cell.det())

This doesn't work:

FAILED tests/calculators/test_workflow.py::TestWorkflow::test_dtype_device[EwaldCalculator-params1] - RuntimeError: output with shape [1, 1] doesn't match the broadcast shape [1, 2]

@sirmarcel
Contributor Author

sirmarcel commented Oct 16, 2024

# Actual computation of trigonometric factors
cos_all = torch.cos(trig_args)
sin_all = torch.sin(trig_args)
cos_summed_G = torch.sum(cos_all * charges_reshaped, dim=sum_idx) * G
sin_summed_G = torch.sum(sin_all * charges_reshaped, dim=sum_idx) * G

energy = (cos_summed_G @ cos_all + sin_summed_G @ sin_all).T
energy /= torch.abs(cell.det())

this passes all the tests!

... and is faster than vmap:

Running 128_tight_cuda...
{'PME': 0.006494242312328424, 'E': 0.0014419457511394285}
  
Running 128_light_cuda...
{'PME': 0.002784314343443839, 'E': 0.0014653211237600772}
  
Running 1024_tight_cuda...
{'PME': 0.006684518875772483, 'E': 0.0024445456874673255}
  
Running 1024_light_cuda...
{'PME': 0.003159262781991856, 'E': 0.0017616292825550772}
  
Running 3456_tight_cuda...
{'PME': 0.010124633625309798, 'E': 0.009587466782249976}
  
Running 3456_light_cuda...
{'PME': 0.00336796584269905, 'E': 0.0018633734998729778}

@sirmarcel
Contributor Author

@kvhuguenin, can you please check that @ceriottm's modification does the right thing? I'm not sure I really understand the charge reshaping stuff. If we're sure, I'll make a PR to fix this.

sirmarcel added a commit that referenced this issue Oct 18, 2024
sirmarcel added a commit that referenced this issue Oct 21, 2024