Skip to content

Commit

Permalink
Make global soln indices deterministic (#1599)
Browse files Browse the repository at this point in the history
* Make global soln index values deterministic
  • Loading branch information
b-shi authored Jan 30, 2025
1 parent 675d0e3 commit d9063a8
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 25 deletions.
2 changes: 1 addition & 1 deletion tensilelite/Tensile/ClientWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def main(config, cxxCompiler: str, cCompiler: str):

clientParametersPaths = []
for logicFileName in logicFiles:
(scheduleName, _, problemType, _, exactLogic, newLibrary, _) \
(scheduleName, _, problemType, _, exactLogic, newLibrary) \
= LibraryIO.parseLibraryLogicFile(logicFileName, cxxCompiler)
if problemType["DataType"].isHalf():
enableHalf = True
Expand Down
8 changes: 5 additions & 3 deletions tensilelite/Tensile/Contractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,10 @@ class Solution:

@classmethod
def FromSolutionStruct(cls, solution, cxxCompiler: str):
return cls.FromOriginalState(solution._state, cxxCompiler)
return cls.FromOriginalState(solution._state, cxxCompiler, solution.srcName)

@classmethod
def FromOriginalState(cls, d, cxxCompiler, deviceInfo=None):
def FromOriginalState(cls, d, cxxCompiler, srcName = "", deviceInfo=None):
rv = cls()


Expand Down Expand Up @@ -711,7 +711,8 @@ def FromOriginalState(cls, d, cxxCompiler, deviceInfo=None):
d['CUCount'] = None

rv.hardwarePredicate = Hardware.HardwarePredicate.FromHardware(d['ISA'], d['CUCount'])
rv.originalSolution = OriginalSolution(d, cxxCompiler)
rv.originalSolution = OriginalSolution(d, cxxCompiler, srcName)
rv.srcName = srcName

return rv

Expand All @@ -729,6 +730,7 @@ def __init__(self, **kwargs):
self.libraryLogicIndex = {}
self.index = None
self.ideals = {}
self.srcName = ""

for key, value in kwargs:
if key not in Solution.StateKeys and key not in Solution.HiddenKeys:
Expand Down
8 changes: 3 additions & 5 deletions tensilelite/Tensile/LibraryIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def parseSolutionsData(data, srcFile, cxxCompiler):
# Force the parameters to be re-derived so that logic YAMLs from older versions can still be validated.
solutionState["AssignedProblemIndependentDerivedParameters"] = False
solutionState["AssignedDerivedParameters"] = False
solutionObject = Solution(solutionState, cxxCompiler)
solutionObject = Solution(solutionState, cxxCompiler, srcFile)
solutions.append(solutionObject)
problemType = solutions[0]["ProblemType"]
problemSizes = ProblemSizes(problemType, problemSizesConfig)
Expand All @@ -240,8 +240,6 @@ class LibraryLogic(NamedTuple):
solutions: list
exactLogic: list
library: SolutionLibrary.MasterSolutionLibrary
srcFile: str


def parseLibraryLogicFile(filename, cxxCompiler, archs=None):
"""Wrapper function to read and parse a library logic file."""
Expand Down Expand Up @@ -293,7 +291,7 @@ def solutionStateToSolution(solutionState, cxxCompiler) -> Solution:
# The ActivationType setting in the YAML is meaningless in the customKernel case.
# Therefore, we override the customKernel setting with the ActivationType value from ProblemType to avoid false alarms during the subsequent problemType checks.
solutionState["ProblemType"]["ActivationType"] = problemType["ActivationType"]
solutionObject = Solution(solutionState, cxxCompiler)
solutionObject = Solution(solutionState, cxxCompiler, srcFile)
solutionProblemType = solutionObject["ProblemType"]
if problemType != solutionProblemType:
# find the mismatched items in ProblemType
Expand All @@ -310,7 +308,7 @@ def solutionStateToSolution(solutionState, cxxCompiler) -> Solution:
newLibrary, _ = SolutionLibrary.MasterSolutionLibrary.FromOriginalState(data, solutions, cxxCompiler)

return LibraryLogic(data["ScheduleName"], data["ArchitectureName"], problemType, solutions, \
data.get("ExactLogic"), newLibrary, srcFile)
data.get("ExactLogic"), newLibrary)


def parseLibraryLogicList(data, srcFile="?"):
Expand Down
3 changes: 2 additions & 1 deletion tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,9 +1055,10 @@ def isExtractableIndex(ks, index, tc='x'):
class Solution(collections.abc.Mapping):

########################################
def __init__(self, config, cxxCompiler: str):
def __init__(self, config, cxxCompiler: str, srcName: str = ""):
self._name = None
self.cxxCompiler = cxxCompiler
self.srcName = srcName
config = config

self._state = {}
Expand Down
47 changes: 32 additions & 15 deletions tensilelite/Tensile/TensileCreateLibrary/Run.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult):
with open(path, "w", encoding="utf-8") as f:
f.write(result.src)
del result # result.src is very large so let gc know to clean up asap

return path, isa, wfsize


Expand All @@ -144,15 +144,15 @@ def writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNE
kernelHeaderFilename = os.path.join(os.path.normcase(outputPath), KERNEL_HELPER_FILENAME_H)

with open(kernelHeaderFilename, "w", encoding="utf-8") as kernelHeaderFile, \
open(kernelSourceFilename, "w", encoding="utf-8") as kernelSourceFile:
open(kernelSourceFilename, "w", encoding="utf-8") as kernelSourceFile:
kernelSourceFile.write(CHeader)
kernelHeaderFile.write(CHeader)
kernelSourceFile.write("#include \"Kernels.h\"\n")
kernelHeaderFile.write("#pragma once\n")
if globalParameters["RuntimeLanguage"] == "HIP":
kernelHeaderFile.write("#include <hip/hip_runtime.h>\n")
kernelHeaderFile.write("#include <hip/hip_ext.h>\n\n")
kernelHeaderFile.write("#include \"KernelHeader.h\"\n\n")
kernelHeaderFile.write("#include \"KernelHeader.h\"\n\n")
HeaderText = ""
for ko in kernelHelperObjs:
kernelName = ko.getKernelName()
Expand Down Expand Up @@ -199,7 +199,7 @@ def assemble(ret):

writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H)
srcKernelFile = Path(outputPath) / "Kernels.cpp"

if not generateSourcesAndExit:
codeObjectFiles += buildAssemblyCodeObjectFiles(asmToolchain, asmKernels, kernelWriterAssembly, outputPath, compress)
buildSourceCodeObjectFile(srcToolchain, outputPath, fromTensile, srcKernelFile)
Expand Down Expand Up @@ -319,7 +319,7 @@ def generateLogicDataAndSolutions(logicFiles, args, cxxCompiler):
solutions = []
masterLibraries = {}
nextSolIndex = 0
matchTable = {}

fIter = zip(logicFiles, itertools.repeat(cxxCompiler), itertools.repeat(archs))

def libraryIter(lib: MasterSolutionLibrary):
Expand All @@ -331,7 +331,7 @@ def libraryIter(lib: MasterSolutionLibrary):
yield from libraryIter(lazyLib)

for library in ParallelMap2(LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered"):
_, architectureName, _, _, _, newLibrary, srcFile = library
_, architectureName, _, _, _, newLibrary = library

if architectureName == "":
continue
Expand All @@ -341,11 +341,31 @@ def libraryIter(lib: MasterSolutionLibrary):
else:
masterLibraries[architectureName] = newLibrary
masterLibraries[architectureName].version = args["CodeObjectVersion"]

if args["GenSolTable"]:
# Match yaml file solutions to solution index
for localIdx, _, s in libraryIter(newLibrary):
matchTable[s.index] = [srcFile, localIdx]

# Sort masterLibraries to make global soln index values deterministic
solnReIndex=0
masterLibraries = dict(sorted(masterLibraries.items()))
for k,v in masterLibraries.items():
for _, masterLibrary in masterLibraries.items():
for _, sol in masterLibrary.solutions.items():
sol.index = solnReIndex
solnReIndex += 1
# Sort masterLibrary to make global soln index values deterministic
masterLibrary.lazyLibraries = dict(sorted(masterLibrary.lazyLibraries.items()))
for name, lib in masterLibrary.lazyLibraries.items():
# Sort solns by the lib logic file they were generated from
lib.solutions = {k: lib.solutions[k] for k in sorted(lib.solutions, key = lambda idx: lib.solutions[idx].srcName )}
for _, sol in lib.solutions.items():
sol.index = solnReIndex
solnReIndex += 1

if args["GenSolTable"]:
matchTable = {}
# Match yaml file solutions to solution index
for _,masterLibrary in masterLibraries.items():
for localIdx, _, s in libraryIter(masterLibrary):
matchTable[s.index] = [s.srcName, localIdx]
LibraryIO.write("MatchTable", matchTable)

if "fallback" in masterLibraries.keys():
for key, value in masterLibraries.items():
Expand All @@ -363,9 +383,6 @@ def libraryIter(lib: MasterSolutionLibrary):
# remove duplicates while preserving order
solutions = dict.fromkeys(solutions).keys()

if args["GenSolTable"]:
LibraryIO.write("MatchTable", matchTable)

return solutions, masterLibraries


Expand Down Expand Up @@ -450,7 +467,7 @@ def validLogicFile(p: Path):

copyStaticFiles(arguments["OutputPath"])

numKernels = writeSolutionsAndKernelsTCL(arguments["OutputPath"], asmToolchain, srcToolchain, kernels,
numKernels = writeSolutionsAndKernelsTCL(arguments["OutputPath"], asmToolchain, srcToolchain, kernels,
kernelHelperObjs, kernelWriterAssembly, compress=arguments["UseCompression"])

archs = [getGfxName(arch) for arch in globalParameters['SupportedISA'] \
Expand Down

0 comments on commit d9063a8

Please sign in to comment.