Skip to content

Commit

Permalink
Refactor / tidy-up of setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gmarkall committed Jul 25, 2024
1 parent 0cf4d1f commit ac859f8
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import logging
import pathlib

from setuptools import setup
from setuptools.command.build_py import build_py
from setuptools.command.editable_wheel import editable_wheel, _TopLevelFinder

REDIRECTOR_PTH = "_numba_cuda_redirector.pth"
REDIRECTOR_PY = "_numba_cuda_redirector.py"
SITE_PACKAGES = pathlib.Path("site-packages")


# Adapted from https://stackoverflow.com/a/71137790
class build_py_with_redirector(build_py): # noqa: N801
Expand All @@ -17,26 +20,22 @@ def copy_redirector_file(self, source, destination="."):

def run(self):
super().run()
site_packages = pathlib.Path("site-packages")
self.copy_redirector_file(site_packages / "_numba_cuda_redirector.pth")
self.copy_redirector_file(site_packages / "_numba_cuda_redirector.py")
self.copy_redirector_file(SITE_PACKAGES / REDIRECTOR_PTH)
self.copy_redirector_file(SITE_PACKAGES / REDIRECTOR_PY)

def get_source_files(self):
src = super().get_source_files()
site_packages = pathlib.Path("site-packages")
src.extend([
str(site_packages / "_numba_cuda_redirector.pth"),
str(site_packages / "_numba_cuda_redirector.py"),
str(SITE_PACKAGES / REDIRECTOR_PTH),
str(SITE_PACKAGES / REDIRECTOR_PY),
])
return src

def get_output_mapping(self):
mapping = super().get_output_mapping()
build_lib = pathlib.Path(self.build_lib)
mapping[str(build_lib / "_numba_cuda_redirector.pth")] = \
"_numba_cuda_redirector.pth"
mapping[str(build_lib / "_numba_cuda_redirector.py")] = \
"_numba_cuda_redirector.py"
mapping[str(build_lib / REDIRECTOR_PTH)] = REDIRECTOR_PTH
mapping[str(build_lib / REDIRECTOR_PY)] = REDIRECTOR_PY
return mapping


Expand All @@ -47,15 +46,11 @@ def get_implementation(self):
for item in super().get_implementation():
yield item

site_packages = pathlib.Path("site-packages")
pth_file = "_numba_cuda_redirector.pth"
py_file = "_numba_cuda_redirector.py"

with open(site_packages / pth_file) as f:
yield (pth_file, f.read())
with open(SITE_PACKAGES / REDIRECTOR_PTH) as f:
yield (REDIRECTOR_PTH, f.read())

with open(site_packages / py_file) as f:
yield (py_file, f.read())
with open(SITE_PACKAGES / REDIRECTOR_PY) as f:
yield (REDIRECTOR_PY, f.read())


class editable_wheel_with_redirector(editable_wheel):
Expand Down

0 comments on commit ac859f8

Please sign in to comment.