-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathsetup.py
64 lines (55 loc) · 2.16 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import subprocess
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
class BuildLibPairD3(build_ext):
    """``build_ext`` command that builds ``sevenn.libpaird3`` directly with nvcc.

    Every other extension is delegated to the stock setuptools
    ``build_ext``. Building libpaird3 is best-effort: when nvcc is not
    usable or the compilation fails, a message is printed and the build
    continues without the shared library.
    """

    # CUDA source compiled into the shared library. The path is relative,
    # so this assumes the build runs from the project root (TODO confirm
    # for all install paths).
    _CUDA_SOURCE = 'sevenn/pair_e3gnn/pair_d3_for_ase.cu'

    # Compute capabilities to emit code for; one
    # ``-gencode arch=compute_XX,code=sm_XX`` pair is generated per entry.
    # You can add more architectures here.
    _CUDA_ARCHS = ('61', '70', '75', '80', '86', '89', '90')

    @staticmethod
    def _nvcc_available():
        """Return True if ``nvcc --version`` runs and exits with status 0."""
        try:
            subprocess.run(
                ['nvcc', '--version'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=True,
            )
        except (OSError, subprocess.CalledProcessError):
            # OSError covers a missing binary (FileNotFoundError) and a
            # non-executable one; CalledProcessError covers a broken install.
            return False
        return True

    def build_extension(self, ext):
        """Compile ``ext`` with nvcc if it is libpaird3, else defer to super.

        Hack: We override the build_extension method to compile the CUDA
        code. Using torch.utils.cpp_extension.CUDAExtension is a better
        solution, but it introduces a dependency in pip's build-isolation
        environment. To avoid this, we manually compile the CUDA code with
        nvcc.

        Args:
            ext: the setuptools ``Extension`` currently being built.
        """
        import os

        if ext.name != 'sevenn.libpaird3':
            super().build_extension(ext)
            return

        if not self._nvcc_available():
            print(
                'CUDA is not installed or nvcc is not available. '
                'Skipping compilation of libpaird3.'
            )
            return
        print('CUDA is installed. Starting compilation of libpaird3.')

        target_path = self.get_ext_fullpath(ext.name)
        # nvcc is invoked directly, bypassing the setuptools compiler that
        # would normally create the output directory, so ensure it exists.
        self.mkpath(os.path.dirname(target_path))

        command = [
            'nvcc',
            '-o', str(target_path),
            '-shared',
            '-fmad=false',
            '-O3',
            '--expt-relaxed-constexpr',
            self._CUDA_SOURCE,
            '-Xcompiler', '-fPIC', '-lcudart',
        ]
        for arch in self._CUDA_ARCHS:
            command += ['-gencode', f'arch=compute_{arch},code=sm_{arch}']

        try:
            subprocess.run(command, check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            # Best-effort: report and keep installing the rest of the package.
            print(f'Failed to compile libpaird3.so: {e}')
            return
        print('libpaird3.so compiled successfully.')
# Placeholder extension: it has no sources because BuildLibPairD3 compiles
# the CUDA code for it directly with nvcc instead of the default compiler.
_libpaird3_extension = Extension('sevenn.libpaird3', sources=[])

setup(
    cmdclass={'build_ext': BuildLibPairD3},
    ext_modules=[_libpaird3_extension],
)