Skip to content

Commit

Permalink
Update mlu-ops build logic to use static library
Browse files Browse the repository at this point in the history
  • Loading branch information
ClowDragon authored and fuwenguang committed Mar 25, 2024
1 parent 93025c4 commit daf46da
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 36 deletions.
4 changes: 2 additions & 2 deletions build.property
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
"version": "v1.1.0",
"official_version": "2.1",
"build_requires": {
"catch_113": ["1.18.0"],
"catch_210": ["24.01"],
"catch_113": ["1.19.0"],
"catch_210": ["1.19.0"],
"mluops": ["v1.0.0"]},
"args_combination": {
"wheel1": ["1.13", "3.10"],
Expand Down
16 changes: 16 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

chmod -R 777 mlu-ops

pushd mlu-ops
source env.sh
./independent_build.sh --no_prepare --disable-gtest --enable-static
popd

if [ -d "./mmcv/lib" ]; then
echo "mmcv/lib directory already existed!"
else
echo "Creating mmcv/lib directory!"
mkdir mmcv/lib
fi
cp mlu-ops/build/lib/*.a* ./mmcv/lib
82 changes: 48 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ def convert_to_pt_version(version_string):
mlu_unique_version = "+" + version_value + "+" + pt_version
return mluops_version[0], mlu_unique_version

def get_local_mluops_version():
neuware_home = os.environ.get('NEUWARE_HOME')
if not neuware_home:
print("NEUWARE_HOME environment variable not set.")
return ""

local_lib_path = os.path.join(neuware_home, "lib64")
try:
matched_files = [fn for fn in os.listdir(local_lib_path) if fn.startswith("libmluops")]
pattern = re.compile(r'\.so\.(\d+\.\d+\.\d+)$')
for filename in matched_files:
match = pattern.search(filename)
if match:
local_mluops_version = match.group(1)
print("Found local mluops_version:", local_mluops_version)
return local_mluops_version
print("No matching mlu-ops found.")
return ""
except OSError as e:
print(f"Error accessing directory: {e}")
return ""

def parse_requirements(fname='requirements/runtime.txt', with_version=True):
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.
Expand Down Expand Up @@ -292,6 +314,7 @@ def get_extensions():
from torch_mlu.utils.cpp_extension import MLUExtension

mmcv_mluops_version, _ = get_mlu_version()
local_mluops_version = get_local_mluops_version()
mlu_ops_path = os.getenv('MMCV_MLU_OPS_PATH')
if mlu_ops_path:
exists_mluops_version, _ = get_mlu_version()
Expand All @@ -315,24 +338,30 @@ def get_extensions():
'or rename or remove it.')
else:
if not os.path.exists('mlu-ops'):
import requests
mluops_url = 'https://github.com/Cambricon/mlu-ops/' + \
'archive/refs/tags/' + mmcv_mluops_version + '.zip'
req = requests.get(mluops_url)
with open('./mlu-ops.zip', 'wb') as f:
try:
f.write(req.content)
except Exception:
raise ImportError('failed to download mlu-ops')

from zipfile import BadZipFile, ZipFile
with ZipFile('./mlu-ops.zip', 'r') as archive:
try:
archive.extractall()
dir_name = archive.namelist()[0].split('/')[0]
os.rename(dir_name, 'mlu-ops')
except BadZipFile:
print('invalid mlu-ops.zip file')
if parse_version(local_mluops_version) >= parse_version(mmcv_mluops_version[1:]):
include_dirs.append(os.path.abspath(os.environ.get('NEUWARE_HOME') + '/include/'))
else:
import requests
mluops_url = 'https://github.com/Cambricon/mlu-ops/' + \
'archive/refs/tags/' + mmcv_mluops_version + '.zip'
req = requests.get(mluops_url)
with open('./mlu-ops.zip', 'wb') as f:
try:
f.write(req.content)
except Exception:
raise ImportError('failed to download mlu-ops')

from zipfile import BadZipFile, ZipFile
with ZipFile('./mlu-ops.zip', 'r') as archive:
try:
archive.extractall()
dir_name = archive.namelist()[0].split('/')[0]
os.rename(dir_name, 'mlu-ops')
except BadZipFile:
print('invalid mlu-ops.zip file')
os.system("bash build.sh")
extra_objects.append(os.path.abspath('./mmcv/lib/libmluops.a'))
include_dirs.append(os.path.abspath('./mlu-ops/'))
else:
exists_mluops_version, _ = get_mlu_version()
if exists_mluops_version != mmcv_mluops_version:
Expand All @@ -350,31 +379,16 @@ def get_extensions():
if parse_version(local_torch_version) < parse_version('2.3.0'):
define_macros += [('MMCV_WITH_TORCH_OLD', None)]
mlu_args = os.getenv('MMCV_MLU_ARGS', '-DNDEBUG ')
mluops_includes = []
mluops_includes.append(
'-I' + os.path.abspath('./mlu-ops/kernels'))
mluops_includes.append('-I' + os.path.abspath('./mlu-ops/'))
extra_compile_args['cncc'] = [mlu_args] + \
mluops_includes if mlu_args else mluops_includes
extra_compile_args['cxx'] += ['-fno-gnu-unique']
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
glob.glob(
'./mlu-ops/core/*.cpp', recursive=True) + \
glob.glob(
'./mlu-ops/core/*/*/*.cpp', recursive=True) + \
glob.glob(
'./mlu-ops/kernels/**/*.cpp', recursive=True) + \
glob.glob(
'./mlu-ops/kernels/**/*.mlu', recursive=True)
glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp')
extra_link_args = [
'-Wl,--whole-archive',
'-Wl,--no-whole-archive'
]
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mlu-ops/'))
elif (hasattr(torch.backends, 'mps')
and torch.backends.mps.is_available()) or os.getenv(
'FORCE_MPS', '0') == '1':
Expand Down

0 comments on commit daf46da

Please sign in to comment.