Skip to content

WIP: more enum layout optimizations #101819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4249,6 +4249,7 @@ dependencies = [
"rustc_target",
"rustc_trait_selection",
"rustc_type_ir",
"smallvec",
"tracing",
]

Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3065,6 +3065,8 @@ mod size_asserts {
static_assert_size!(PathSegment, 24);
static_assert_size!(Stmt, 32);
static_assert_size!(StmtKind, 16);
static_assert_size!(Ty, 96);
static_assert_size!(TyKind, 72);
#[cfg(not(bootstrap))]
static_assert_size!(Ty, 88);
#[cfg(not(bootstrap))]
static_assert_size!(TyKind, 64);
}
2 changes: 2 additions & 0 deletions compiler/rustc_attr/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ pub enum ReprAttr {
ReprSimd,
ReprTransparent,
ReprAlign(u32),
ReprFlag,
}

#[derive(Eq, PartialEq, Debug, Copy, Clone)]
Expand Down Expand Up @@ -998,6 +999,7 @@ pub fn parse_repr_attr(sess: &Session, attr: &Attribute) -> Vec<ReprAttr> {
recognised = true;
None
}
sym::flag => Some(ReprFlag),
name => int_type_of_word(name).map(ReprInt),
};

Expand Down
10 changes: 2 additions & 8 deletions compiler/rustc_codegen_cranelift/src/abi/comments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,8 @@ pub(super) fn add_local_place_comments<'tcx>(
return;
}
let TyAndLayout { ty, layout } = place.layout();
let rustc_target::abi::LayoutS {
size,
align,
abi: _,
variants: _,
fields: _,
largest_niche: _,
} = layout.0.0;
let rustc_target::abi::LayoutS { size, align, abi: _, variants: _, fields: _, niches: _ } =
layout.0.0;

let (kind, extra) = match *place.inner() {
CPlaceInner::Var(place_local, var) => {
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_codegen_cranelift/src/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
Variants::Multiple {
tag: _,
tag_field,
tag_encoding: TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, .. },
variants: _,
} => {
if variant_index != untagged_variant {
Expand Down Expand Up @@ -113,7 +114,7 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
let res = CValue::by_val(val, dest_layout);
dest.write_cvalue(fx, res);
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, .. } => {
// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.

Expand Down
50 changes: 26 additions & 24 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,32 +430,34 @@ fn compute_discriminant_value<'ll, 'tcx>(
enum_type_and_layout.ty.discriminant_for_variant(cx.tcx, variant_index).unwrap().val,
),
&Variants::Multiple {
tag_encoding: TagEncoding::Niche { ref niche_variants, niche_start, untagged_variant },
tag,
//tag_encoding: TagEncoding::Niche { ref niche_variants, niche_start, untagged_variant, .. },
//tag,
..
} => {
if variant_index == untagged_variant {
let valid_range = enum_type_and_layout
.for_variant(cx, variant_index)
.largest_niche
.as_ref()
.unwrap()
.valid_range;

let min = valid_range.start.min(valid_range.end);
let min = tag.size(cx).truncate(min);

let max = valid_range.start.max(valid_range.end);
let max = tag.size(cx).truncate(max);

DiscrResult::Range(min, max)
} else {
let value = (variant_index.as_u32() as u128)
.wrapping_sub(niche_variants.start().as_u32() as u128)
.wrapping_add(niche_start);
let value = tag.size(cx).truncate(value);
DiscrResult::Value(value)
}
// YYY
DiscrResult::Range(0, 1)
//if variant_index == untagged_variant {
// let valid_range = enum_type_and_layout
// .for_variant(cx, variant_index)
// .largest_niche
// .as_ref()
// .unwrap()
// .valid_range;

// let min = valid_range.start.min(valid_range.end);
// let min = tag.size(cx).truncate(min);

// let max = valid_range.start.max(valid_range.end);
// let max = tag.size(cx).truncate(max);

// DiscrResult::Range(min, max)
//} else {
// let value = (variant_index.as_u32() as u128)
// .wrapping_sub(niche_variants.start().as_u32() as u128)
// .wrapping_add(niche_start);
// let value = tag.size(cx).truncate(value);
// DiscrResult::Value(value)
//}
}
}
}
111 changes: 70 additions & 41 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::operand::OperandValue;
use super::{FunctionCx, LocalRef};

use crate::common::IntPredicate;
use crate::common::{IntPredicate, TypeKind};
use crate::glue;
use crate::traits::*;

Expand Down Expand Up @@ -227,13 +227,13 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}
};

// Read the tag/niche-encoded discriminant from memory.
let tag = self.project_field(bx, tag_field);
let tag = bx.load_operand(tag);
let tag_place = self.project_field(bx, tag_field);

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
// Read the tag from memory.
let tag = bx.load_operand(tag_place);
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
Expand All @@ -244,11 +244,30 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
};
bx.intcast(tag.immediate(), cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.
let niche_llty = bx.cx().immediate_backend_type(tag.layout);
let tag = tag.immediate();
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, } => {
let read = |bx: &mut Bx, place: Self| -> (V, <Bx as BackendTypes>::Type) {
let ty = bx.cx().immediate_backend_type(place.layout);
let op = bx.load_operand(place);
let val = op.immediate();
if bx.cx().type_kind(ty) == TypeKind::Pointer {
let new_ty = bx.cx().type_isize();
let new_val = bx.ptrtoint(val, new_ty);
(new_val, new_ty)
} else {
(val, ty)
}
};

let (tag, niche_llty) = read(bx, tag_place);

let (untagged_in_niche, flag_eq_magic_value_opt) = if let Some(flag) = flag {
let flag_place = self.project_field(bx, flag.field);
let (flag_imm, flag_llty) = read(bx, flag_place);
let magic_value = bx.cx().const_uint_big(flag_llty, flag.magic_value);
(flag.untagged_in_niche, Some(bx.icmp(IntPredicate::IntEQ, flag_imm, magic_value)))
} else {
(true, None)
};

// We first compute the "relative discriminant" (wrt `niche_variants`),
// that is, if `n = niche_variants.end() - niche_variants.start()`,
Expand All @@ -259,23 +278,8 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
// and check that it is in the range `niche_variants`, because
// that might not fit in the same type, on top of needing an extra
// comparison (see also the comment on `let niche_discr`).
let relative_discr = if niche_start == 0 {
// Avoid subtracting `0`, which wouldn't work for pointers.
// FIXME(eddyb) check the actual primitive type here.
tag
} else {
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
};
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start));
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
} else {
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
};

// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
Expand All @@ -285,7 +289,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let potential_niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
Expand All @@ -299,11 +303,29 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
)
};

bx.select(
is_niche,
niche_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
)
let untagged_discr = bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64);

let niche_discr = if untagged_in_niche {
let relative_max_const = bx.cx().const_uint(niche_llty, relative_max as u64);
let is_niche = bx.icmp(IntPredicate::IntULE, relative_discr, relative_max_const);
bx.select(
is_niche,
potential_niche_discr,
untagged_discr,
)
} else {
potential_niche_discr
};

if let Some(flag_eq_magic_value) = flag_eq_magic_value_opt {
bx.select(
flag_eq_magic_value,
niche_discr,
untagged_discr,
)
} else {
niche_discr
}
}
}
}
Expand Down Expand Up @@ -337,23 +359,30 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}
Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, },
tag_field,
..
} => {
let store = |bx: &mut Bx, value: u128, place: Self| {
let ty = bx.cx().immediate_backend_type(place.layout);
let val = if bx.cx().type_kind(ty) == TypeKind::Pointer {
let ty_isize = bx.cx().type_isize();
let llvalue = bx.cx().const_uint_big(ty_isize, value);
bx.inttoptr(llvalue, ty)
} else {
bx.cx().const_uint_big(ty, value)
};
OperandValue::Immediate(val).store(bx, place);
};
if variant_index != untagged_variant {
if let Some(flag) = flag {
let place = self.project_field(bx, flag.field);
store(bx, flag.magic_value, place);
}
let niche = self.project_field(bx, tag_field);
let niche_llty = bx.cx().immediate_backend_type(niche.layout);
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(niche_start);
// FIXME(eddyb): check the actual primitive type here.
let niche_llval = if niche_value == 0 {
// HACK(eddyb): using `c_null` as it works on all types.
bx.cx().const_null(niche_llty)
} else {
bx.cx().const_uint_big(niche_llty, niche_value)
};
OperandValue::Immediate(niche_llval).store(bx, niche);
store(bx, niche_value, niche);
}
}
}
Expand Down
46 changes: 38 additions & 8 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,27 +722,55 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// Return the cast value, and the index.
(discr_val, index.0)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, } => {
let is_magic_val = if let Some(flag) = flag {
let flag_val = self.read_immediate(&self.operand_field(op, flag.field)?)?;
let flag_val = flag_val.to_scalar();
match flag_val.try_to_int() {
Err(dbg_val) => {
// So this is a pointer then, and casting to an int
// failed. Can only happen during CTFE. If the magic
// value is 0 and the scalar is not null, we know
// the pointer cannot be the magic value. Anything
// else we conservatively reject.
let ptr_definitely_not_magic_value =
flag.magic_value == 0 && !self.scalar_may_be_null(flag_val)?;
if !ptr_definitely_not_magic_value {
throw_ub!(InvalidTag(dbg_val))
}
false
}
Ok(flag_bits) => {
let flag_layout =
self.layout_of(flag.scalar.primitive().to_int_ty(*self.tcx))?;
let flag_bits = flag_bits.assert_bits(flag_layout.size);
flag_bits == flag.magic_value
}
}
} else {
true
};
let tag_val = tag_val.to_scalar();
// Compute the variant this niche value/"tag" corresponds to. With niche layout,
// discriminant (encoded in niche/tag) and variant index are the same.
let variants_start = niche_variants.start().as_u32();
let variants_end = niche_variants.end().as_u32();
let variant = match tag_val.try_to_int() {
Err(dbg_val) => {
let variant = match (is_magic_val, tag_val.try_to_int()) {
(false, _) => untagged_variant,
(true, Err(dbg_val)) => {
// So this is a pointer then, and casting to an int failed.
// Can only happen during CTFE.
// The niche must be just 0, and the ptr not null, then we know this is
// okay. Everything else, we conservatively reject.
let ptr_valid = niche_start == 0
let ptr_definitely_not_in_niche_variants = niche_start == 0
&& variants_start == variants_end
&& !self.scalar_may_be_null(tag_val)?;
if !ptr_valid {
if !ptr_definitely_not_in_niche_variants {
throw_ub!(InvalidTag(dbg_val))
}
untagged_variant
}
Ok(tag_bits) => {
(true, Ok(tag_bits)) => {
let tag_bits = tag_bits.assert_bits(tag_layout.size);
// We need to use machine arithmetic to get the relative variant idx:
// variant_index_relative = tag_val - niche_start_val
Expand Down Expand Up @@ -791,6 +819,8 @@ mod size_asserts {
// These are in alphabetical order, which is easy to maintain.
static_assert_size!(Immediate, 48);
static_assert_size!(ImmTy<'_>, 64);
static_assert_size!(Operand, 56);
static_assert_size!(OpTy<'_>, 80);
#[cfg(not(bootstrap))]
static_assert_size!(Operand, 48);
#[cfg(not(bootstrap))]
static_assert_size!(OpTy<'_>, 72);
}
11 changes: 9 additions & 2 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,15 +823,22 @@ where
}
abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, },
tag: tag_layout,
tag_field,
..
} => {
// No need to validate that the discriminant here because the
// No need to validate the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.

if variant_index != untagged_variant {
if let Some(flag) = flag {
let flag_layout = self.layout_of(flag.scalar.primitive().to_int_ty(*self.tcx))?;
let val = ImmTy::from_uint(flag.magic_value, flag_layout);
let flag_dest = self.place_field(dest, flag.field)?;
self.write_immediate(*val, &flag_dest)?;
}

let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
Expand Down
Loading