Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Half dtype and mixed precision training. #77

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

maskjp
Copy link

@maskjp maskjp commented Jul 27, 2021

Hi,

Thank you for this great library and torch-points3d.

I made some modifications to support mixed-precision training.

The major changes are as follows:

  • template kernel function in interpolate_gpu.cu sampling_gpu.cu and ball_query_gpu.cu.
  • Change AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND_HALF;
  • Change atomicAdd to gpuAtomicAdd (in pytorch THCAtomics.cuh);
  • Add custom_fwd and custom_bwd in torch.autograd.Function to allow autocast;
  • fixed bug of sampling_gpu which the first element of idx output is always 0; (But I found that the output of GPU version and CPU version are not the same. I haven't fixed this.)

The modified version passed the tests in the test folder. I didn't see affection on full-precision training. And I tried to train the PointNet2 model in the torch-point3d library in a mixed-precision style, it works.

@nicolas-chaulet
Copy link
Member

Amazing!!! Thank you so much for contributing, this is a really needed feature. Tagging @CCInc so that he can take a look as well.

@clee-ai
Copy link
Member

clee-ai commented Jul 28, 2021

@maskjp Thanks for the contribution!

I'm curious which version of pytorch you compiled against, did they change the tensor namespace to torch at somepoint?

I'm guessing the models use ~50% the memory compared to full precision? Did you notice any training speed increase too?

@clee-ai
Copy link
Member

clee-ai commented Jul 28, 2021

@nicolas-chaulet we should add precommit.ci to this repo, so we can ensure consistent formatting on PRs, what do you think? Also, maybe we could add a pytorch version matrix to unittesting? Maybe 1.7.0 to latest?

Copy link
Member

@clee-ai clee-ai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks good at first glance! I will look at it closer later on, in the meantime can you clean up all the comments and install/run pre-commit for the code formatting?

int b = xyz.size(0);
int n = xyz.size(1);
int m = new_xyz.size(1);
torch::Tensor idx =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can use auto on all tensor types I think, to make it a little cleaner


// three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
// unknowns.DATA_PTR<float>(), knows.DATA_PTR<float>(),
// dist2.DATA_PTR<float>(), idx.DATA_PTR<int>());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can delete all the commented lines from the files I think

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @CCInc,

I removed the comments and made the changes.

@maskjp
Copy link
Author

maskjp commented Jul 28, 2021

I'm guessing the models use ~50% of the memory compared to full precision? Did you notice any training speed increase too?

Hi,@CCInc,

I used torch1.8.1+cu111 to compile it. About the tensor namespace, torch tensor has a higher level, we can also use at too. I just noticed that in chamfer_dist.cpp, cubic_feature_sampling.cpp the torch namespace are used but interpolate.cpp, sampling.cpp and ball_query.cpp used at tensor. I refer to the toturial of pytorch and torch_geomtrics and decide to use torch tensor.

Yes, the modification saves the memory, but I didn't see the training speed increase too. The speed also depends on the model architecture, ops, and io.

@nicolas-chaulet
Copy link
Member

Looks good to me, @CCInc could you please verify that the gpu tests pass on your machine? I don't have access to gpus anymore... Thanks! And yes to pre-commit ci!

@nicolas-chaulet nicolas-chaulet requested a review from clee-ai August 2, 2021 09:23
@clee-ai
Copy link
Member

clee-ai commented Aug 2, 2021

@maskjp I'm getting an issue with the testing, it seems like the cpu and gpu fps are not matching up?

FAIL: test_gpu (test.test_fps.TestFps)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/mnt/f/data/PC/torch-points-kernels/test/__init__.py", line 7, in wrapped_func
    return func(*args, **kwargs)
  File "/mnt/f/data/PC/torch-points-kernels/test/test_fps.py", line 35, in test_gpu
    torch.testing.assert_allclose(sorted_idx,sorted_idx_cpu)
  File "/home/chris/miniconda3/envs/tpk/lib/python3.7/site-packages/torch/testing/_core.py", line 270, in assert_allclose
    raise AssertionError(msg)
AssertionError: Found 27 different element(s) (out of 32), with the greatest difference of 63 (82 vs. 19) occuring at index (8, 1).

@maskjp
Copy link
Author

maskjp commented Aug 3, 2021

@CCInc , Yes, test_gpu in test_fps.py file is created by me. The original test didn't test the GPU version of fps. I found that the output of the CPU and GPU versions are different even before my modification.

@nicolas-chaulet
Copy link
Member

nicolas-chaulet commented Aug 3, 2021 via email

@clee-ai
Copy link
Member

clee-ai commented Aug 3, 2021

Thanks for clarifying! In addition to what Nicolas suggested, would you mind also add a test that verifies the functionality of cuda fps on its own, similar to test_simplecpu?

The rest of the PR works well!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants