Skip to content

Commit

Permalink
Fix compat issue
Browse files Browse the repository at this point in the history
  • Loading branch information
xia-mc committed Dec 13, 2024
1 parent 8995634 commit b65a6aa
Show file tree
Hide file tree
Showing 2 changed files with 5,946 additions and 6,883 deletions.
39 changes: 22 additions & 17 deletions scripts/GenerateAVX512Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@
}
{arg_parsing}
{operation}
Py_RETURN_NONE;
#else
PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
return nullptr;
Expand Down Expand Up @@ -631,15 +628,13 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->
else:
code = f" {func.name}({args});"

isOuter = True
if len(constantRequire) > 0:
if len(constantRequire) > 4:
raise NotImplementedError(len(constantRequire))

for i, (argId, immediateMin, immediateMax) in enumerate(constantRequire):
if immediateMin == immediateMax:
code = code.replace(f"arg{argId}", str(immediateMin))
isOuter = False
continue

result = f"switch (arg{argId}) {{\n"
Expand All @@ -664,16 +659,6 @@ def getCallCode(func: SIMDFunc, constantRequire: list[tuple[int, int, int]]) ->

result += "}"

if isOuter:
result = f"""
#if defined(__clang__) || defined(__GNUC__)
{result}
#else
PyErr_SetString(PyExc_NotImplementedError, "AVX-512 is not supported on this architecture.");
return nullptr;
#endif
"""

return result
return code

Expand Down Expand Up @@ -729,12 +714,25 @@ def main():
if curImmediateGenerated != 1:
immediateGenerated += curImmediateGenerated

callCode = getCallCode(function, constantRequire)
if len(constantRequire) > 0:
argParseCode = f"#if defined(__clang__) || defined(__GNUC__)\n{argParseCode}"
callCode = (f"{formatCode(callCode)}\n\n"
f" Py_RETURN_NONE;\n"
f"#else\n"
f" PyErr_SetString(PyExc_NotImplementedError, \"Target C Method require immediate numbers, "
f"and this method is not supported in GCC/Clang now.\");\n"
f" return nullptr;\n"
f"#endif")
else:
callCode = f"{callCode}\n\n Py_RETURN_NONE;"

funcCode = funcCode.replace("{num_args}", str(num_args))
funcCode = funcCode.replace("{function_name}", function.name)
funcCode = funcCode.replace("{arg_parsing}", argParseCode)

# operation
funcCode = funcCode.replace("{operation}", getCallCode(function, constantRequire))
funcCode = funcCode.replace("{arg_parsing}", argParseCode)
funcCode = funcCode.replace("{operation}", callCode)
functionsGenerated += 1
function_def += funcCode

Expand Down Expand Up @@ -799,5 +797,12 @@ def main():
print(f"Generated '{RESULT_PYI_FILE}' with {size} bytes and {lines} lines.")


def formatCode(code: str) -> str:
result = ""
for line in code.split("\n"):
result += f" {line}\n"
return result


if __name__ == '__main__':
main()
Loading

0 comments on commit b65a6aa

Please sign in to comment.