Skip to content

Commit

Permalink
Specify an alignment on vm.rodata and use it in flatbuffers. (#5494)
Browse files Browse the repository at this point in the history
This required a minor tweak to start_as_root the flatbuffer root tables
prior to recording any data. This was discovered thanks to the gracious
help (and patience) of of the flatcc author:
dvidelabs/flatcc#179

With this all of our binary blobs embedded into the flatbuffers are now
16-byte aligned with the option to tweak it further via a vm.rodata
attribute. We don't need utf8 strings to be 16-byte aligned, for example.
  • Loading branch information
benvanik authored Apr 16, 2021
1 parent 6d0aae7 commit 3c8880e
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 41 deletions.
7 changes: 4 additions & 3 deletions iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ class CUDATargetBackend final : public TargetBackend {

llvmModule->setDataLayout(targetMachine->createDataLayout());

std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
FlatbufferBuilder builder;
iree_CUDAExecutableDef_start_as_root(builder);

// Serialize cuda kernel into the binary that we will embed in the
// final flatbuffer.
FlatbufferBuilder builder;
std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
auto ptxCudeRef = flatbuffers_uint8_vec_create(
builder, reinterpret_cast<const uint8_t *>(targetISA.c_str()),
targetISA.size());
Expand All @@ -168,7 +170,6 @@ class CUDATargetBackend final : public TargetBackend {
}
auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder);

iree_CUDAExecutableDef_start_as_root(builder);
iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef);
iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef);
Expand Down
5 changes: 3 additions & 2 deletions iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,11 @@ class LLVMAOTTargetBackend final : public TargetBackend {
linkArtifacts.keepAllFiles();
}

FlatbufferBuilder builder;
iree_DyLibExecutableDef_start_as_root(builder);

// Embed debug symbols at the end of the flatbuffer by adding first in the
// bottoms-up builder.
FlatbufferBuilder builder;
flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0;
flatbuffers_string_ref_t debugDatabaseFilenameRef = 0;
if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
Expand All @@ -328,7 +330,6 @@ class LLVMAOTTargetBackend final : public TargetBackend {
<< linkArtifacts.libraryFile.path;
}

iree_DyLibExecutableDef_start_as_root(builder);
iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef);
iree_DyLibExecutableDef_debug_database_filename_add(
builder, debugDatabaseFilenameRef);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend {

// 4. Pack the MTLLibrary and metadata into a flatbuffer.
FlatbufferBuilder builder;
iree_MetalExecutableDef_start_as_root(builder);

auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
mslShaders, [&](const MetalShader &shader) { return shader.source; }));
Expand All @@ -135,7 +136,6 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend {

auto entryPointNamesRef = builder.createStringVec(entryPointNames);

iree_MetalExecutableDef_start_as_root(builder);
iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef);
iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef);
iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef);
Expand Down
5 changes: 3 additions & 2 deletions iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ class VMLATargetBackend final : public TargetBackend {

LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes directly into a flatbuffer.
FlatbufferBuilder builder;
iree_VMLAExecutableDef_start_as_root(builder);

// Serialize the VM module to bytes directly into a flatbuffer.
IREE::VM::BytecodeTargetOptions bytecodeOptions;
auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) {
return succeeded(translateModuleToBytecode(targetOp.getInnerModule(),
Expand All @@ -115,7 +117,6 @@ class VMLATargetBackend final : public TargetBackend {

// Pack the executable definition and get the bytes with the proper header.
// The header is used to verify the contents at runtime.
iree_VMLAExecutableDef_start_as_root(builder);
iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef);
iree_VMLAExecutableDef_end_as_root(builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
ModuleOp innerModuleOp = targetOp.getInnerModule();
auto spvModuleOp = *innerModuleOp.getOps<spirv::ModuleOp>().begin();

FlatbufferBuilder builder;
iree_SpirVExecutableDef_start_as_root(builder);

// Serialize the spirv::ModuleOp into the binary that we will embed in the
// final flatbuffer.
FlatbufferBuilder builder;
SmallVector<uint32_t, 256> spvBinary;
if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
return targetOp.emitError() << "failed to serialize spv.module";
Expand All @@ -157,7 +159,6 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
}
auto entryPointsRef = builder.createStringVec(entryPointNames);

iree_SpirVExecutableDef_start_as_root(builder);
iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef);
iree_SpirVExecutableDef_code_add(builder, spvCodeRef);
iree_SpirVExecutableDef_end_as_root(builder);
Expand Down
3 changes: 2 additions & 1 deletion iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ static Value createStringTableValue(Location loc, StringAttr attrValue,
return rewriter.create<IREE::VM::RodataInlineOp>(
loc,
IREE::VM::RefType::get(IREE::ByteBufferType::get(rewriter.getContext())),
rewriter.getStringAttr(safeIdentifier), utf8Bytes);
rewriter.getStringAttr(safeIdentifier), utf8Bytes,
/*alignment=*/rewriter.getI64IntegerAttr(1));
}

size_t getSegmentSpanSize(Type spanType) {
Expand Down
18 changes: 14 additions & 4 deletions iree/compiler/Dialect/VM/IR/VMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -754,18 +754,26 @@ def VM_RodataOp : VM_Op<"rodata", [
value leaves the module. For example, returning rodata from an exported
function must keep the data (possibly backed by mmap) valid for its entire
lifetime.

By default all rodata will be aligned in the final module output at a
16-byte granularity. An optional alignment can be specified to override the
default for cases where larger or smaller alignments are needed.
}];

let arguments = (ins
StrAttr:$sym_name,
ElementsAttr:$value,
OptionalAttr<I64Attr>:$alignment,
OptionalAttr<VM_Ordinal>:$ordinal
);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "StringRef":$name, "ElementsAttr":$value,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilder<(ins
"StringRef":$name,
"ElementsAttr":$value,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs
)>,
];
}

Expand Down Expand Up @@ -810,12 +818,14 @@ def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [
]> {
let summary = [{inlined constant rodata}];
let description = [{
vm.rodata that can be embedded inline in functions.
vm.rodata that can be embedded inline in functions. See vm.rodata for more
information.
}];

let arguments = (ins
OptionalAttr<StrAttr>:$name,
ElementsAttr:$value
ElementsAttr:$value,
OptionalAttr<I64Attr>:$alignment
);

let results = (outs
Expand Down
19 changes: 17 additions & 2 deletions iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef(
static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions,
IREE::VM::ModuleOp moduleOp,
FlatbufferBuilder &fbb) {
// Start the buffer so that we can begin recording data prior to the root
// table (which we do at the very end). This does not change the layout of the
// file and is only used to prime the flatcc builder.
iree_vm_BytecodeModuleDef_start_as_root(fbb);

SymbolTable symbolTable(moduleOp);
if (!moduleOp.ordinal_counts().hasValue()) {
return moduleOp.emitError() << "ordinal_counts attribute not found. The "
Expand Down Expand Up @@ -315,9 +320,20 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions,
// layout planning by preserving the order in the IR is useful.
SmallVector<flatbuffers_uint8_vec_ref_t, 8> rodataContentRefs;
rodataContentRefs.reserve(rodataOps.size());

// All constants are defaulted to 16-byte aligned as that is the maximum
// (reasonable) alignment of all data types on all platforms. This can be
// overridden by creators of the rodata with the `alignment` attribute.
static constexpr int kDefaultRodataAlignment = 16;

for (auto rodataOp : llvm::reverse(rodataOps)) {
size_t alignment =
rodataOp.alignment()
? static_cast<size_t>(rodataOp.alignment().getValue())
: 0;
if (alignment == 0) alignment = kDefaultRodataAlignment;
auto rodataRef =
serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb);
serializeConstant(rodataOp.getLoc(), rodataOp.value(), alignment, fbb);
if (!rodataRef) {
return rodataOp.emitOpError() << "failed to encode";
}
Expand Down Expand Up @@ -461,7 +477,6 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions,
auto moduleNameRef = fbb.createString(
moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name());

iree_vm_BytecodeModuleDef_start_as_root(fbb);
iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef);
iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
Expand Down
43 changes: 22 additions & 21 deletions iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ namespace VM {
// TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness.

static flatbuffers_uint8_vec_ref_t serializeConstantI8Array(
DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
// vm.rodata and other very large constants end up as this; since i8 is i8
// everywhere (endianness doesn't matter when you have one byte :) we can
// directly access the data and memcpy.
flatbuffers_uint8_vec_start(fbb);
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr =
flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t));
if (attr.isSplat()) {
Expand All @@ -47,8 +47,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI8Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantI16Array(
DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
fbb, attr.getNumElements() * sizeof(int16_t));
uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
Expand All @@ -59,8 +59,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI16Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantI32Array(
DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
fbb, attr.getNumElements() * sizeof(int32_t));
uint32_t *nativePtr = reinterpret_cast<uint32_t *>(bytePtr);
Expand All @@ -71,8 +71,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI32Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantI64Array(
DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
fbb, attr.getNumElements() * sizeof(int64_t));
uint64_t *nativePtr = reinterpret_cast<uint64_t *>(bytePtr);
Expand All @@ -83,8 +83,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI64Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantF32Array(
DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr =
flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float));
float *nativePtr = reinterpret_cast<float *>(bytePtr);
Expand All @@ -95,8 +95,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF32Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantF64Array(
DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr =
flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double));
double *nativePtr = reinterpret_cast<double *>(bytePtr);
Expand All @@ -107,8 +107,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF64Array(
}

static flatbuffers_uint8_vec_ref_t serializeConstantF16Array(
DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
flatbuffers_uint8_vec_start(fbb);
DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) {
flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
fbb, attr.getNumElements() * sizeof(uint16_t));
uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
Expand All @@ -121,17 +121,18 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF16Array(

flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
ElementsAttr elementsAttr,
size_t alignment,
FlatbufferBuilder &fbb) {
if (auto attr = elementsAttr.dyn_cast<DenseIntElementsAttr>()) {
switch (attr.getType().getElementTypeBitWidth()) {
case 8:
return serializeConstantI8Array(attr, fbb);
return serializeConstantI8Array(attr, alignment, fbb);
case 16:
return serializeConstantI16Array(attr, fbb);
return serializeConstantI16Array(attr, alignment, fbb);
case 32:
return serializeConstantI32Array(attr, fbb);
return serializeConstantI32Array(attr, alignment, fbb);
case 64:
return serializeConstantI64Array(attr, fbb);
return serializeConstantI64Array(attr, alignment, fbb);
default:
emitError(loc) << "unhandled element bitwidth "
<< attr.getType().getElementTypeBitWidth();
Expand All @@ -140,11 +141,11 @@ flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
} else if (auto attr = elementsAttr.dyn_cast<DenseFPElementsAttr>()) {
switch (attr.getType().getElementTypeBitWidth()) {
case 16:
return serializeConstantF16Array(attr, fbb);
return serializeConstantF16Array(attr, alignment, fbb);
case 32:
return serializeConstantF32Array(attr, fbb);
return serializeConstantF32Array(attr, alignment, fbb);
case 64:
return serializeConstantF64Array(attr, fbb);
return serializeConstantF64Array(attr, alignment, fbb);
default:
emitError(loc) << "unhandled element bitwidth "
<< attr.getType().getElementTypeBitWidth();
Expand Down
1 change: 1 addition & 0 deletions iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace VM {
// Serializes a constant attribute to the FlatBuffer as a binary blob.
flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
ElementsAttr elementsAttr,
size_t alignment,
FlatbufferBuilder &fbb);

} // namespace VM
Expand Down
3 changes: 3 additions & 0 deletions iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class HoistInlinedRodataPass
auto rodataOp = OpBuilder(moduleOp.getContext())
.create<IREE::VM::RodataOp>(inlineOp.getLoc(), name,
inlineOp.value());
if (inlineOp.alignmentAttr()) {
rodataOp.alignmentAttr(inlineOp.alignmentAttr());
}
moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint());
rodataOp.setPrivate();
replaceInlineOpWithRodataRef(inlineOp, rodataOp);
Expand Down
4 changes: 2 additions & 2 deletions iree/compiler/Utils/FlatbufferUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); }
FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); }

flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec(
std::function<bool(raw_ostream &stream)> fn) {
flatbuffers_uint8_vec_start(*this);
std::function<bool(raw_ostream &stream)> fn, size_t alignment) {
flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1));
raw_flatbuffer_uint8_vec_ostream stream(*this);
if (!fn(stream)) {
return 0;
Expand Down
2 changes: 1 addition & 1 deletion iree/compiler/Utils/FlatbufferUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class FlatbufferBuilder {
// my_type_uint8_vec_field_add(builder, ref); // use vec reference
// ...
flatbuffers_uint8_vec_ref_t streamUint8Vec(
std::function<bool(raw_ostream &stream)> fn);
std::function<bool(raw_ostream &stream)> fn, size_t alignment = 16);

// Captures the current contents of the flatbuffer builder and returns them
// as a shaped `vector<SIZExi8>` dense attr. The builder is left unmodified.
Expand Down

0 comments on commit 3c8880e

Please sign in to comment.