Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fallback {float} to f32 when f32: From<{float}> and add impl From<f16> for f32 #139087

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ language_item_table! {
DefaultTrait3, sym::default_trait3, default_trait3_trait, Target::Trait, GenericRequirement::None;
DefaultTrait2, sym::default_trait2, default_trait2_trait, Target::Trait, GenericRequirement::None;
DefaultTrait1, sym::default_trait1, default_trait1_trait, Target::Trait, GenericRequirement::None;

// Used to fallback `{float}` to `f32` when `f32: From<{float}>`
From, sym::From, from_trait, Target::Trait, GenericRequirement::Exact(1);
}

/// The requirement imposed on the generics of a lang item
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_hir_typeck/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ hir_typeck_field_multiply_specified_in_initializer =
.label = used more than once
.previous_use_label = first use of `{$ident}`

hir_typeck_float_literal_f32_fallback =
falling back to `f32` as the trait bound `f32: From<f64>` is not satisfied
.suggestion = explicitly specify the type as `f32`

hir_typeck_fn_item_to_variadic_function = can't pass a function item to a variadic function
.suggestion = use a function pointer instead
.help = a function item is zero-sized and needs to be cast into a function pointer to be used in FFI
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/demand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
match infer {
ty::TyVar(_) => self.next_ty_var(DUMMY_SP),
ty::IntVar(_) => self.next_int_var(),
ty::FloatVar(_) => self.next_float_var(),
ty::FloatVar(_) => self.next_float_var(DUMMY_SP),
ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
bug!("unexpected fresh ty outside of the trait solver")
}
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_hir_typeck/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -979,3 +979,11 @@ pub(crate) enum SupertraitItemShadowee {
traits: DiagSymbolList,
},
}

#[derive(LintDiagnostic)]
#[diag(hir_typeck_float_literal_f32_fallback)]
pub(crate) struct FloatLiteralF32Fallback {
pub literal: String,
#[suggestion(code = "{literal}_f32", applicability = "machine-applicable")]
pub span: Option<Span>,
}
95 changes: 90 additions & 5 deletions compiler/rustc_hir_typeck/src/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ use rustc_data_structures::graph::iterate::DepthFirstSearch;
use rustc_data_structures::graph::vec_graph::VecGraph;
use rustc_data_structures::graph::{self};
use rustc_data_structures::unord::{UnordBag, UnordMap, UnordSet};
use rustc_hir as hir;
use rustc_hir::HirId;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::{InferKind, Visitor};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
use rustc_hir::{self as hir, CRATE_HIR_ID, HirId};
use rustc_lint::builtin::FLOAT_LITERAL_F32_FALLBACK;
use rustc_middle::ty::{
self, ClauseKind, FloatVid, PredicatePolarity, TraitPredicate, Ty, TyCtxt, TypeSuperVisitable,
TypeVisitable,
};
use rustc_session::lint;
use rustc_span::def_id::LocalDefId;
use rustc_span::{DUMMY_SP, Span};
Expand Down Expand Up @@ -92,14 +95,16 @@ impl<'tcx> FnCtxt<'_, 'tcx> {

let diverging_fallback = self
.calculate_diverging_fallback(&unresolved_variables, self.diverging_fallback_behavior);
let fallback_to_f32 = self.calculate_fallback_to_f32(&unresolved_variables);

// We do fallback in two passes, to try to generate
// better error messages.
// The first time, we do *not* replace opaque types.
let mut fallback_occurred = false;
for ty in unresolved_variables {
debug!("unsolved_variable = {:?}", ty);
fallback_occurred |= self.fallback_if_possible(ty, &diverging_fallback);
fallback_occurred |=
self.fallback_if_possible(ty, &diverging_fallback, &fallback_to_f32);
}

fallback_occurred
Expand All @@ -109,7 +114,8 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
//
// - Unconstrained ints are replaced with `i32`.
//
// - Unconstrained floats are replaced with `f64`.
// - Unconstrained floats are replaced with `f64`, except when there is a trait predicate
// `f32: From<{float}>`, in which case `f32` is used as the fallback instead.
//
// - Non-numerics may get replaced with `()` or `!`, depending on
// how they were categorized by `calculate_diverging_fallback`
Expand All @@ -124,6 +130,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
&self,
ty: Ty<'tcx>,
diverging_fallback: &UnordMap<Ty<'tcx>, Ty<'tcx>>,
fallback_to_f32: &UnordSet<FloatVid>,
) -> bool {
// Careful: we do NOT shallow-resolve `ty`. We know that `ty`
// is an unsolved variable, and we determine its fallback
Expand All @@ -146,6 +153,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
let fallback = match ty.kind() {
_ if let Some(e) = self.tainted_by_errors() => Ty::new_error(self.tcx, e),
ty::Infer(ty::IntVar(_)) => self.tcx.types.i32,
ty::Infer(ty::FloatVar(vid)) if fallback_to_f32.contains(vid) => self.tcx.types.f32,
ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64,
_ => match diverging_fallback.get(&ty) {
Some(&fallback_ty) => fallback_ty,
Expand All @@ -160,6 +168,78 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
true
}

/// Existing code relies on `f32: From<T>` (usually written as `T: Into<f32>`) resolving `T` to
/// `f32` when the type of `T` is inferred from an unsuffixed float literal. Using the default
/// fallback of `f64`, this would break when adding `impl From<f16> for f32`, as there are now
/// two float type which could be `T`, meaning that the fallback of `f64` would be used and
/// compilation error would occur as `f32` does not implement `From<f64>`. To avoid breaking
/// existing code, we instead fallback `T` to `f32` when there is a trait predicate
/// `f32: From<T>`. This means code like the following will continue to compile:
///
/// ```rust
/// fn foo<T: Into<f32>>(_: T) {}
///
/// foo(1.0);
/// ```
fn calculate_fallback_to_f32(&self, unresolved_variables: &[Ty<'tcx>]) -> UnordSet<FloatVid> {
let Some(from_trait) = self.tcx.lang_items().from_trait() else {
return UnordSet::new();
};
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
debug!("calculate_fallback_to_f32: pending_obligations={:?}", pending_obligations);
let roots: UnordSet<ty::FloatVid> = pending_obligations
.into_iter()
.filter_map(|obligation| {
// The predicates we are looking for look like
// `TraitPredicate(<f32 as std::convert::From<{float}>>, polarity:Positive)`.
// They will have no bound variables.
obligation.predicate.kind().no_bound_vars()
})
.filter_map(|predicate| match predicate {
ty::PredicateKind::Clause(ClauseKind::Trait(TraitPredicate {
polarity: PredicatePolarity::Positive,
trait_ref,
})) if trait_ref.def_id == from_trait
&& self.shallow_resolve(trait_ref.self_ty()).kind()
== &ty::Float(ty::FloatTy::F32) =>
{
self.root_float_vid(trait_ref.args.type_at(1))
}
_ => None,
})
.collect();
debug!("calculate_fallback_to_f32: roots={:?}", roots);
if roots.is_empty() {
// Most functions have no `f32: From<{float}>` predicates, so short-circuit and return
// an empty set when this is the case.
return UnordSet::new();
}
// Calculate all the unresolved variables that need to fallback to `f32` here. This ensures
// we don't need to find root variables in `fallback_if_possible`: see the comment at the
// top of that function for details.
let fallback_to_f32 = unresolved_variables
.iter()
.flat_map(|ty| ty.float_vid())
.filter(|vid| roots.contains(&self.root_float_var(*vid)))
.inspect(|vid| {
let span = self.float_var_origin(*vid);
// Show the entire literal in the suggestion to make it clearer.
let literal = self.tcx.sess.source_map().span_to_snippet(span).ok();
self.tcx.emit_node_span_lint(
FLOAT_LITERAL_F32_FALLBACK,
CRATE_HIR_ID,
span,
errors::FloatLiteralF32Fallback {
span: literal.as_ref().map(|_| span),
literal: literal.unwrap_or_default(),
},
);
})
.collect();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do the mapping from root to all unified infer vars here instead of looking at the root in fallback_if_possible? Please add a comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a comment. I put that there because that's what calculate_diverging_fallback does and there's a comment at the top of fallback_if_possible stating that "we determine its fallback based solely on how it was created, not what other type variables it may have been unified with since then", saying it makes it easier to detect bugs.

debug!("calculate_fallback_to_f32: fallback_to_f32={:?}", fallback_to_f32);
fallback_to_f32
}

/// The "diverging fallback" system is rather complicated. This is
/// a result of our need to balance 'do the right thing' with
/// backwards compatibility.
Expand Down Expand Up @@ -565,6 +645,11 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
}

/// If `ty` is an unresolved float type variable, returns its root vid.
fn root_float_vid(&self, ty: Ty<'tcx>) -> Option<ty::FloatVid> {
Some(self.root_float_var(self.shallow_resolve(ty).float_vid()?))
}

/// Given a set of diverging vids and coercions, walk the HIR to gather a
/// set of suggestions which can be applied to preserve fallback to unit.
fn try_to_suggest_annotations(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1669,7 +1669,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
ty::Float(_) => Some(ty),
_ => None,
});
opt_ty.unwrap_or_else(|| self.next_float_var())
opt_ty.unwrap_or_else(|| self.next_float_var(lit.span))
}
ast::LitKind::Bool(_) => tcx.types.bool,
ast::LitKind::CStr(_, _) => Ty::new_imm_ref(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/canonical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'tcx> InferCtxt<'tcx> {

CanonicalTyVarKind::Int => self.next_int_var(),

CanonicalTyVarKind::Float => self.next_float_var(),
CanonicalTyVarKind::Float => self.next_float_var(span),
};
ty.into()
}
Expand Down
25 changes: 22 additions & 3 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use rustc_data_structures::unify as ut;
use rustc_errors::{DiagCtxtHandle, ErrorGuaranteed};
use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_index::IndexVec;
use rustc_macros::extension;
pub use rustc_macros::{TypeFoldable, TypeVisitable};
use rustc_middle::bug;
Expand Down Expand Up @@ -109,6 +110,10 @@ pub struct InferCtxtInner<'tcx> {
/// Map from floating variable to the kind of float it represents.
float_unification_storage: ut::UnificationTableStorage<ty::FloatVid>,

/// Map from floating variable to the origin span it came from. This is only used for the FCW
/// for the fallback to `f32`, so can be removed once the `f32` fallback is removed.
float_origin_span_storage: IndexVec<FloatVid, Span>,

/// Tracks the set of region variables and the constraints between them.
///
/// This is initially `Some(_)` but when
Expand Down Expand Up @@ -165,6 +170,7 @@ impl<'tcx> InferCtxtInner<'tcx> {
const_unification_storage: Default::default(),
int_unification_storage: Default::default(),
float_unification_storage: Default::default(),
float_origin_span_storage: Default::default(),
region_constraint_storage: Some(Default::default()),
region_obligations: vec![],
opaque_type_storage: Default::default(),
Expand Down Expand Up @@ -633,6 +639,13 @@ impl<'tcx> InferCtxt<'tcx> {
self.inner.borrow_mut().type_variables().var_origin(vid)
}

/// Returns the origin of the float type variable identified by `vid`.
///
/// No attempt is made to resolve `vid` to its root variable.
pub fn float_var_origin(&self, vid: FloatVid) -> Span {
self.inner.borrow_mut().float_origin_span_storage[vid]
}

/// Returns the origin of the const variable identified by `vid`
// FIXME: We should store origins separately from the unification table
// so this doesn't need to be optional.
Expand Down Expand Up @@ -818,9 +831,11 @@ impl<'tcx> InferCtxt<'tcx> {
Ty::new_int_var(self.tcx, next_int_var_id)
}

pub fn next_float_var(&self) -> Ty<'tcx> {
let next_float_var_id =
self.inner.borrow_mut().float_unification_table().new_key(ty::FloatVarValue::Unknown);
pub fn next_float_var(&self, span: Span) -> Ty<'tcx> {
let mut inner = self.inner.borrow_mut();
let next_float_var_id = inner.float_unification_table().new_key(ty::FloatVarValue::Unknown);
let span_index = inner.float_origin_span_storage.push(span);
debug_assert_eq!(next_float_var_id, span_index);
Ty::new_float_var(self.tcx, next_float_var_id)
}

Expand Down Expand Up @@ -1061,6 +1076,10 @@ impl<'tcx> InferCtxt<'tcx> {
self.inner.borrow_mut().type_variables().root_var(var)
}

pub fn root_float_var(&self, var: ty::FloatVid) -> ty::FloatVid {
self.inner.borrow_mut().float_unification_table().find(var)
}

pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid {
self.inner.borrow_mut().const_unification_table().find(var).vid
}
Expand Down
26 changes: 19 additions & 7 deletions compiler/rustc_infer/src/infer/snapshot/fudge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ use rustc_middle::ty::{
self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
TypeSuperFoldable,
};
use rustc_span::Span;
use rustc_type_ir::TypeVisitableExt;
use tracing::instrument;
use ut::UnifyKey;

use super::VariableLengths;
use crate::infer::type_variable::TypeVariableOrigin;
use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
use crate::infer::{
ConstVariableOrigin, InferCtxt, InferCtxtInner, RegionVariableOrigin, UnificationTable,
};

fn vars_since_snapshot<'tcx, T>(
table: &UnificationTable<'_, 'tcx, T>,
Expand All @@ -25,6 +28,14 @@ where
T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
}

fn float_vars_since_snapshot(
inner: &mut InferCtxtInner<'_>,
snapshot_var_len: usize,
) -> (Range<FloatVid>, Vec<Span>) {
let range = vars_since_snapshot(&inner.float_unification_table(), snapshot_var_len);
(range.clone(), range.map(|index| inner.float_origin_span_storage[index]).collect())
}

fn const_vars_since_snapshot<'tcx>(
table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
snapshot_var_len: usize,
Expand Down Expand Up @@ -128,7 +139,7 @@ struct SnapshotVarData {
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
int_vars: Range<IntVid>,
float_vars: Range<FloatVid>,
float_vars: (Range<FloatVid>, Vec<Span>),
const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
}

Expand All @@ -141,8 +152,7 @@ impl SnapshotVarData {
let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
let int_vars =
vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
let float_vars =
vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
let float_vars = float_vars_since_snapshot(&mut inner, vars_pre_snapshot.float_var_len);

let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
Expand All @@ -156,7 +166,7 @@ impl SnapshotVarData {
region_vars.0.is_empty()
&& type_vars.0.is_empty()
&& int_vars.is_empty()
&& float_vars.is_empty()
&& float_vars.0.is_empty()
&& const_vars.0.is_empty()
}
}
Expand Down Expand Up @@ -201,8 +211,10 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
}
}
ty::FloatVar(vid) => {
if self.snapshot_vars.float_vars.contains(&vid) {
self.infcx.next_float_var()
if self.snapshot_vars.float_vars.0.contains(&vid) {
let idx = vid.as_usize() - self.snapshot_vars.float_vars.0.start.as_usize();
let span = self.snapshot_vars.float_vars.1[idx];
self.infcx.next_float_var(span)
} else {
ty
}
Expand Down
Loading
Loading