diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp index 02cb2ffbc708..e98d51765cac 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp @@ -146,10 +146,12 @@ class CUDATargetBackend final : public TargetBackend { llvmModule->setDataLayout(targetMachine->createDataLayout()); - std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); + FlatbufferBuilder builder; + iree_CUDAExecutableDef_start_as_root(builder); + // Serialize cuda kernel into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; + std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); auto ptxCudeRef = flatbuffers_uint8_vec_create( builder, reinterpret_cast(targetISA.c_str()), targetISA.size()); @@ -168,7 +170,6 @@ class CUDATargetBackend final : public TargetBackend { } auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder); - iree_CUDAExecutableDef_start_as_root(builder); iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef); iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef); iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef); diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index 38cab2dda692..97ab76d13ee4 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp @@ -305,9 +305,11 @@ class LLVMAOTTargetBackend final : public TargetBackend { linkArtifacts.keepAllFiles(); } + FlatbufferBuilder builder; + iree_DyLibExecutableDef_start_as_root(builder); + // Embed debug symbols at the end of the flatbuffer by adding first in the // bottoms-up builder. - FlatbufferBuilder builder; flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0; flatbuffers_string_ref_t debugDatabaseFilenameRef = 0; if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) { @@ -328,7 +330,6 @@ class LLVMAOTTargetBackend final : public TargetBackend { << linkArtifacts.libraryFile.path; } - iree_DyLibExecutableDef_start_as_root(builder); iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef); iree_DyLibExecutableDef_debug_database_filename_add( builder, debugDatabaseFilenameRef); diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index a9e5fe879c01..3bb9b7fafdf9 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -121,6 +121,7 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { // 4. Pack the MTLLibrary and metadata into a flatbuffer. FlatbufferBuilder builder; + iree_MetalExecutableDef_start_as_root(builder); auto shaderSourcesRef = builder.createStringVec(llvm::map_range( mslShaders, [&](const MetalShader &shader) { return shader.source; })); @@ -135,7 +136,6 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { auto entryPointNamesRef = builder.createStringVec(entryPointNames); - iree_MetalExecutableDef_start_as_root(builder); iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef); iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef); iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef); diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp index 8e00d15b2365..6cd7fc550f02 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp @@ -102,8 +102,10 @@ class VMLATargetBackend final : public TargetBackend { LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp, OpBuilder &executableBuilder) override { - // Serialize the VM module to bytes directly into a flatbuffer. FlatbufferBuilder builder; + iree_VMLAExecutableDef_start_as_root(builder); + + // Serialize the VM module to bytes directly into a flatbuffer. IREE::VM::BytecodeTargetOptions bytecodeOptions; auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) { return succeeded(translateModuleToBytecode(targetOp.getInnerModule(), @@ -115,7 +117,6 @@ class VMLATargetBackend final : public TargetBackend { // Pack the executable definition and get the bytes with the proper header. // The header is used to verify the contents at runtime. - iree_VMLAExecutableDef_start_as_root(builder); iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef); iree_VMLAExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index de7c23794aca..e1a556e25d40 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -128,9 +128,11 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { ModuleOp innerModuleOp = targetOp.getInnerModule(); auto spvModuleOp = *innerModuleOp.getOps().begin(); + FlatbufferBuilder builder; + iree_SpirVExecutableDef_start_as_root(builder); + // Serialize the spirv::ModuleOp into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; SmallVector spvBinary; if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) { return targetOp.emitError() << "failed to serialize spv.module"; @@ -157,7 +159,6 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { } auto entryPointsRef = builder.createStringVec(entryPointNames); - iree_SpirVExecutableDef_start_as_root(builder); iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef); iree_SpirVExecutableDef_code_add(builder, spvCodeRef); iree_SpirVExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index f5dac0305c55..a495ebfd807c 100644 --- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -117,7 +117,8 @@ static Value createStringTableValue(Location loc, StringAttr attrValue, return rewriter.create( loc, IREE::VM::RefType::get(IREE::ByteBufferType::get(rewriter.getContext())), - rewriter.getStringAttr(safeIdentifier), utf8Bytes); + rewriter.getStringAttr(safeIdentifier), utf8Bytes, + /*alignment=*/rewriter.getI64IntegerAttr(1)); } size_t getSegmentSpanSize(Type spanType) { diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index 82a344febe7a..408f4e646e09 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td @@ -754,18 +754,26 @@ def VM_RodataOp : VM_Op<"rodata", [ value leaves the module. For example, returning rodata from an exported function must keep the data (possibly backed by mmap) valid for its entire lifetime. + + By default all rodata will be aligned in the final module output at a + 16-byte granularity. An optional alignment can be specified to override the + default for cases where larger or smaller alignments are needed. }]; let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value, + OptionalAttr:$alignment, OptionalAttr:$ordinal ); let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "ElementsAttr":$value, - CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins + "StringRef":$name, + "ElementsAttr":$value, + CArg<"ArrayRef", "{}">:$attrs + )>, ]; } @@ -810,12 +818,14 @@ def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [ ]> { let summary = [{inlined constant rodata}]; let description = [{ - vm.rodata that can be embedded inline in functions. + vm.rodata that can be embedded inline in functions. See vm.rodata for more + information. }]; let arguments = (ins OptionalAttr:$name, - ElementsAttr:$value + ElementsAttr:$value, + OptionalAttr:$alignment ); let results = (outs diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 821ba43ec517..7c059699c4fa 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -276,6 +276,11 @@ static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef( static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp, FlatbufferBuilder &fbb) { + // Start the buffer so that we can begin recording data prior to the root + // table (which we do at the very end). This does not change the layout of the + // file and is only used to prime the flatcc builder. + iree_vm_BytecodeModuleDef_start_as_root(fbb); + SymbolTable symbolTable(moduleOp); if (!moduleOp.ordinal_counts().hasValue()) { return moduleOp.emitError() << "ordinal_counts attribute not found. The " @@ -316,9 +321,20 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, // layout planning by preserving the order in the IR is useful. SmallVector rodataContentRefs; rodataContentRefs.reserve(rodataOps.size()); + + // All constants are defaulted to 16-byte aligned as that is the maximum + // (reasonable) alignment of all data types on all platforms. This can be + // overridden by creators of the rodata with the `alignment` attribute. + static constexpr int kDefaultRodataAlignment = 16; + for (auto rodataOp : llvm::reverse(rodataOps)) { + size_t alignment = + rodataOp.alignment() + ? static_cast(rodataOp.alignment().getValue()) + : 0; + if (alignment == 0) alignment = kDefaultRodataAlignment; auto rodataRef = - serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb); + serializeConstant(rodataOp.getLoc(), rodataOp.value(), alignment, fbb); if (!rodataRef) { return rodataOp.emitOpError() << "failed to encode"; } @@ -462,7 +478,6 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, auto moduleNameRef = fbb.createString( moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name()); - iree_vm_BytecodeModuleDef_start_as_root(fbb); iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef); iree_vm_BytecodeModuleDef_types_add(fbb, typesRef); iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp index b33908bdfe6f..91a5fb26170a 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp @@ -26,11 +26,11 @@ namespace VM { // TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness. static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { // vm.rodata and other very large constants end up as this; since i8 is i8 // everywhere (endianness doesn't matter when you have one byte :) we can // directly access the data and memcpy. - flatbuffers_uint8_vec_start(fbb); + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t)); if (attr.isSplat()) { @@ -47,8 +47,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -59,8 +59,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int32_t)); uint32_t *nativePtr = reinterpret_cast(bytePtr); @@ -71,8 +71,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int64_t)); uint64_t *nativePtr = reinterpret_cast(bytePtr); @@ -83,8 +83,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float)); float *nativePtr = reinterpret_cast(bytePtr); @@ -95,8 +95,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double)); double *nativePtr = reinterpret_cast(bytePtr); @@ -107,8 +107,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(uint16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -121,17 +121,18 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb) { if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 8: - return serializeConstantI8Array(attr, fbb); + return serializeConstantI8Array(attr, alignment, fbb); case 16: - return serializeConstantI16Array(attr, fbb); + return serializeConstantI16Array(attr, alignment, fbb); case 32: - return serializeConstantI32Array(attr, fbb); + return serializeConstantI32Array(attr, alignment, fbb); case 64: - return serializeConstantI64Array(attr, fbb); + return serializeConstantI64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); @@ -140,11 +141,11 @@ flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, } else if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 16: - return serializeConstantF16Array(attr, fbb); + return serializeConstantF16Array(attr, alignment, fbb); case 32: - return serializeConstantF32Array(attr, fbb); + return serializeConstantF32Array(attr, alignment, fbb); case 64: - return serializeConstantF64Array(attr, fbb); + return serializeConstantF64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h index 56471a633fff..94dca8186ae7 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h @@ -28,6 +28,7 @@ namespace VM { // Serializes a constant attribute to the FlatBuffer as a binary blob. flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb); } // namespace VM diff --git a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp index 240a507f5004..7aa741dd3bd0 100644 --- a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp +++ b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp @@ -60,6 +60,9 @@ class HoistInlinedRodataPass auto rodataOp = OpBuilder(moduleOp.getContext()) .create(inlineOp.getLoc(), name, inlineOp.value()); + if (inlineOp.alignmentAttr()) { + rodataOp.alignmentAttr(inlineOp.alignmentAttr()); + } moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint()); rodataOp.setPrivate(); replaceInlineOpWithRodataRef(inlineOp, rodataOp); diff --git a/iree/compiler/Utils/FlatbufferUtils.cpp b/iree/compiler/Utils/FlatbufferUtils.cpp index e98167b75d40..90d03be0f1d7 100644 --- a/iree/compiler/Utils/FlatbufferUtils.cpp +++ b/iree/compiler/Utils/FlatbufferUtils.cpp @@ -44,8 +44,8 @@ FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); } FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); } flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec( - std::function fn) { - flatbuffers_uint8_vec_start(*this); + std::function fn, size_t alignment) { + flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); raw_flatbuffer_uint8_vec_ostream stream(*this); if (!fn(stream)) { return 0; diff --git a/iree/compiler/Utils/FlatbufferUtils.h b/iree/compiler/Utils/FlatbufferUtils.h index 524cf63f7725..783bffb910ed 100644 --- a/iree/compiler/Utils/FlatbufferUtils.h +++ b/iree/compiler/Utils/FlatbufferUtils.h @@ -108,7 +108,7 @@ class FlatbufferBuilder { // my_type_uint8_vec_field_add(builder, ref); // use vec reference // ... flatbuffers_uint8_vec_ref_t streamUint8Vec( - std::function fn); + std::function fn, size_t alignment = 16); // Captures the current contents of the flatbuffer builder and returns them // as a shaped `vector` dense attr. The builder is left unmodified.