From: Philip Herron Date: Fri, 18 Jul 2025 14:46:59 +0000 (+0100) Subject: gccrs: Add initial support for deffered operator overload resolution X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=9076a8f688886de933d5f953855a3431d3e0a922;p=thirdparty%2Fgcc.git gccrs: Add initial support for deffered operator overload resolution In the test case: fn test (len: usize) -> u64 { let mut i = 0; let mut out = 0; if i + 3 < len { out = 123; } out } The issue is to determine the correct type of 'i', out is simple because it hits a coercion site in the resturn position for u64. But 'i + 3', 'i' is an integer infer variable and the same for the literal '3'. So when it comes to resolving the type for the Add expression we hit the resolve the operator overload code and because of this: macro_rules! add_impl { ($($t:ty)*) => ($( impl Add for $t { type Output = $t; #[inline] #[rustc_inherit_overflow_checks] fn add(self, other: $t) -> $t { self + other } } )*) } add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 } This means the resolution for 'i + 3' is ambigious because it could be any of these Add implementations. But because we unify against the '< len' where len is defined as usize later in the resolution we determine 'i' is actually a usize. Which means if we defer the resolution of this operator overload in the ambigious case we can simply resolve it at the end. Fixes Rust-GCC#3916 gcc/rust/ChangeLog: * hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): track the rhs * hir/tree/rust-hir-expr.h: likewise * hir/tree/rust-hir-path.h: get rid of old comments * typecheck/rust-hir-trait-reference.cc (TraitReference::get_trait_substs): return references instead of copy * typecheck/rust-hir-trait-reference.h: update header * typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::ResolveOpOverload): write ambigious operator overloads to a table and try to resolve it at the end * typecheck/rust-hir-type-check-expr.h: new static helper * typecheck/rust-hir-type-check.h (struct DeferredOpOverload): new model to defer resolution * typecheck/rust-typecheck-context.cc (TypeCheckContext::lookup_operator_overload): new (TypeCheckContext::compute_ambigious_op_overload): likewise (TypeCheckContext::compute_inference_variables): likewise gcc/testsuite/ChangeLog: * rust/compile/issue-3916.rs: New test. Signed-off-by: Philip Herron --- diff --git a/gcc/rust/hir/tree/rust-hir-expr.cc b/gcc/rust/hir/tree/rust-hir-expr.cc index 93dcec2c8d7..038bfc77f94 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.cc +++ b/gcc/rust/hir/tree/rust-hir-expr.cc @@ -17,6 +17,7 @@ // . #include "rust-hir-expr.h" +#include "rust-hir-map.h" #include "rust-operators.h" #include "rust-hir-stmt.h" @@ -1321,37 +1322,40 @@ AsyncBlockExpr::operator= (AsyncBlockExpr const &other) OperatorExprMeta::OperatorExprMeta (HIR::CompoundAssignmentExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ArithmeticOrLogicalExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::NegationExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), + rvalue_mappings (Analysis::NodeMapping::get_error ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::DereferenceExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), + rvalue_mappings (Analysis::NodeMapping::get_error ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_array_expr ().get_mappings ()), + rvalue_mappings (expr.get_index_expr ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} InlineAsmOperand::In::In ( diff --git a/gcc/rust/hir/tree/rust-hir-expr.h b/gcc/rust/hir/tree/rust-hir-expr.h index fcb4744fef4..028455b9870 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.h +++ b/gcc/rust/hir/tree/rust-hir-expr.h @@ -27,6 +27,7 @@ #include "rust-hir-attrs.h" #include "rust-expr.h" #include "rust-hir-map.h" +#include "rust-mapping-common.h" namespace Rust { namespace HIR { @@ -2892,6 +2893,22 @@ public: OperatorExprMeta (HIR::ComparisonExpr &expr); + OperatorExprMeta (const OperatorExprMeta &other) + : node_mappings (other.node_mappings), + lvalue_mappings (other.lvalue_mappings), + rvalue_mappings (other.rvalue_mappings), locus (other.locus) + {} + + OperatorExprMeta &operator= (const OperatorExprMeta &other) + { + node_mappings = other.node_mappings; + lvalue_mappings = other.lvalue_mappings; + rvalue_mappings = other.rvalue_mappings; + locus = other.locus; + + return *this; + } + const Analysis::NodeMapping &get_mappings () const { return node_mappings; } const Analysis::NodeMapping &get_lvalue_mappings () const @@ -2899,11 +2916,22 @@ public: return lvalue_mappings; } + const Analysis::NodeMapping &get_rvalue_mappings () const + { + return rvalue_mappings; + } + + bool has_rvalue_mappings () const + { + return rvalue_mappings.get_hirid () != UNKNOWN_HIRID; + } + location_t get_locus () const { return locus; } private: - const Analysis::NodeMapping node_mappings; - const Analysis::NodeMapping lvalue_mappings; + Analysis::NodeMapping node_mappings; + Analysis::NodeMapping lvalue_mappings; + Analysis::NodeMapping rvalue_mappings; location_t locus; }; diff --git a/gcc/rust/hir/tree/rust-hir-path.h b/gcc/rust/hir/tree/rust-hir-path.h index 3ce2662c802..5f88c6827bb 100644 --- a/gcc/rust/hir/tree/rust-hir-path.h +++ b/gcc/rust/hir/tree/rust-hir-path.h @@ -41,11 +41,15 @@ public: : segment_name (std::move (segment_name)) {} - /* TODO: insert check in constructor for this? Or is this a semantic error - * best handled then? */ + PathIdentSegment (const PathIdentSegment &other) + : segment_name (other.segment_name) + {} - /* TODO: does this require visitor? pretty sure this isn't polymorphic, but - * not entirely sure */ + PathIdentSegment &operator= (PathIdentSegment const &other) + { + segment_name = other.segment_name; + return *this; + } // Creates an error PathIdentSegment. static PathIdentSegment create_error () { return PathIdentSegment (""); } diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.cc b/gcc/rust/typecheck/rust-hir-trait-reference.cc index 88e270d510d..74856f098fa 100644 --- a/gcc/rust/typecheck/rust-hir-trait-reference.cc +++ b/gcc/rust/typecheck/rust-hir-trait-reference.cc @@ -432,7 +432,13 @@ TraitReference::trait_has_generics () const return !trait_substs.empty (); } -std::vector +std::vector & +TraitReference::get_trait_substs () +{ + return trait_substs; +} + +const std::vector & TraitReference::get_trait_substs () const { return trait_substs; diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.h b/gcc/rust/typecheck/rust-hir-trait-reference.h index 8b1ac7daf7f..473513ea75f 100644 --- a/gcc/rust/typecheck/rust-hir-trait-reference.h +++ b/gcc/rust/typecheck/rust-hir-trait-reference.h @@ -224,7 +224,9 @@ public: bool trait_has_generics () const; - std::vector get_trait_substs () const; + std::vector &get_trait_substs (); + + const std::vector &get_trait_substs () const; bool satisfies_bound (const TraitReference &reference) const; diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc index c1404561f4d..5db0e5690c9 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc @@ -17,6 +17,7 @@ // . #include "optional.h" +#include "rust-common.h" #include "rust-hir-expr.h" #include "rust-system.h" #include "rust-tyty-call.h" @@ -59,6 +60,19 @@ TypeCheckExpr::Resolve (HIR::Expr &expr) return resolver.infered; } +TyTy::BaseType * +TypeCheckExpr::ResolveOpOverload (LangItem::Kind lang_item_type, + HIR::OperatorExprMeta expr, + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment) +{ + TypeCheckExpr resolver; + + resolver.resolve_operator_overload (lang_item_type, expr, lhs, rhs, + specified_segment); + return resolver.infered; +} + void TypeCheckExpr::visit (HIR::TupleIndexExpr &expr) { @@ -1885,7 +1899,16 @@ TypeCheckExpr::resolve_operator_overload ( // probe for the lang-item if (!lang_item_defined) return false; + DefId &respective_lang_item_id = lang_item_defined.value (); + auto def_lookup = mappings.lookup_defid (respective_lang_item_id); + rust_assert (def_lookup.has_value ()); + + HIR::Item *def_item = def_lookup.value (); + rust_assert (def_item->get_item_kind () == HIR::Item::ItemKind::Trait); + HIR::Trait &trait = *static_cast (def_item); + TraitReference *defid_trait_reference = TraitResolver::Resolve (trait); + rust_assert (!defid_trait_reference->is_error ()); // we might be in a static or const context and unknown is fine TypeCheckContextItem current_context = TypeCheckContextItem::get_error (); @@ -1929,15 +1952,49 @@ TypeCheckExpr::resolve_operator_overload ( if (selected_candidates.size () > 1) { - // mutliple candidates - rich_location r (line_table, expr.get_locus ()); - for (auto &c : resolved_candidates) - r.add_range (c.candidate.locus); + auto infer + = TyTy::TyVar::get_implicit_infer_var (expr.get_locus ()).get_tyty (); + auto trait_subst = defid_trait_reference->get_trait_substs (); + rust_assert (trait_subst.size () > 0); - rust_error_at ( - r, "multiple candidates found for possible operator overload"); + TyTy::TypeBoundPredicate pred (respective_lang_item_id, trait_subst, + BoundPolarity::RegularBound, + expr.get_locus ()); - return false; + std::vector mappings; + auto &self_param_mapping = trait_subst[0]; + mappings.push_back (TyTy::SubstitutionArg (&self_param_mapping, lhs)); + + if (rhs != nullptr) + { + rust_assert (trait_subst.size () == 2); + auto &rhs_param_mapping = trait_subst[1]; + mappings.push_back (TyTy::SubstitutionArg (&rhs_param_mapping, lhs)); + } + + std::map binding_args; + binding_args["Output"] = infer; + + TyTy::SubstitutionArgumentMappings arg_mappings (mappings, binding_args, + TyTy::RegionParamList ( + trait_subst.size ()), + expr.get_locus ()); + pred.apply_argument_mappings (arg_mappings, false); + + infer->inherit_bounds ({pred}); + DeferredOpOverload defer (expr.get_mappings ().get_hirid (), + lang_item_type, specified_segment, pred, expr); + context->insert_deferred_operator_overload (std::move (defer)); + + if (rhs != nullptr) + lhs = unify_site (expr.get_mappings ().get_hirid (), + TyTy::TyWithLocation (lhs), + TyTy::TyWithLocation (rhs), expr.get_locus ()); + + infered = unify_site (expr.get_mappings ().get_hirid (), + TyTy::TyWithLocation (lhs), + TyTy::TyWithLocation (infer), expr.get_locus ()); + return true; } // Get the adjusted self diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index 53119743685..48f28c70079 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -31,6 +31,11 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor public: static TyTy::BaseType *Resolve (HIR::Expr &expr); + static TyTy::BaseType * + ResolveOpOverload (LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment); + void visit (HIR::TupleIndexExpr &expr) override; void visit (HIR::TupleExpr &expr) override; void visit (HIR::ReturnExpr &expr) override; diff --git a/gcc/rust/typecheck/rust-hir-type-check.h b/gcc/rust/typecheck/rust-hir-type-check.h index 356c55803ed..80e40344835 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.h +++ b/gcc/rust/typecheck/rust-hir-type-check.h @@ -20,6 +20,7 @@ #define RUST_HIR_TYPE_CHECK #include "rust-hir-map.h" +#include "rust-mapping-common.h" #include "rust-tyty.h" #include "rust-hir-trait-reference.h" #include "rust-stacked-contexts.h" @@ -157,6 +158,39 @@ public: WARN_UNUSED_RESULT Lifetime next () { return Lifetime (interner_index++); } }; +struct DeferredOpOverload +{ + HirId expr_id; + LangItem::Kind lang_item_type; + HIR::PathIdentSegment specified_segment; + TyTy::TypeBoundPredicate predicate; + HIR::OperatorExprMeta op; + + DeferredOpOverload (HirId expr_id, LangItem::Kind lang_item_type, + HIR::PathIdentSegment specified_segment, + TyTy::TypeBoundPredicate &predicate, + HIR::OperatorExprMeta op) + : expr_id (expr_id), lang_item_type (lang_item_type), + specified_segment (specified_segment), predicate (predicate), op (op) + {} + + DeferredOpOverload (const struct DeferredOpOverload &other) + : expr_id (other.expr_id), lang_item_type (other.lang_item_type), + specified_segment (other.specified_segment), predicate (other.predicate), + op (other.op) + {} + + DeferredOpOverload &operator= (struct DeferredOpOverload const &other) + { + expr_id = other.expr_id; + lang_item_type = other.lang_item_type; + specified_segment = other.specified_segment; + op = other.op; + + return *this; + } +}; + class TypeCheckContext { public: @@ -237,6 +271,13 @@ public: void insert_operator_overload (HirId id, TyTy::FnType *call_site); bool lookup_operator_overload (HirId id, TyTy::FnType **call); + void insert_deferred_operator_overload (DeferredOpOverload deferred); + bool lookup_deferred_operator_overload (HirId id, + DeferredOpOverload *deferred); + + void iterate_deferred_operator_overloads ( + std::function cb); + void insert_unconstrained_check_marker (HirId id, bool status); bool have_checked_for_unconstrained (HirId id, bool *result); @@ -271,6 +312,7 @@ private: TypeCheckContext (); bool compute_infer_var (HirId id, TyTy::BaseType *ty, bool emit_error); + bool compute_ambigious_op_overload (HirId id, DeferredOpOverload &op); std::map node_id_refs; std::map resolved; @@ -308,6 +350,9 @@ private: std::set querys_in_progress; std::set trait_queries_in_progress; + // deferred operator overload + std::map deferred_operator_overloads; + // variance analysis TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx; diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc b/gcc/rust/typecheck/rust-typecheck-context.cc index 7b3584823e4..83b17612d5e 100644 --- a/gcc/rust/typecheck/rust-typecheck-context.cc +++ b/gcc/rust/typecheck/rust-typecheck-context.cc @@ -18,6 +18,7 @@ #include "rust-hir-type-check.h" #include "rust-type-util.h" +#include "rust-hir-type-check-expr.h" namespace Rust { namespace Resolver { @@ -408,6 +409,38 @@ TypeCheckContext::lookup_operator_overload (HirId id, TyTy::FnType **call) return true; } +void +TypeCheckContext::insert_deferred_operator_overload ( + DeferredOpOverload deferred) +{ + HirId expr_id = deferred.expr_id; + deferred_operator_overloads.emplace (std::make_pair (expr_id, deferred)); +} + +bool +TypeCheckContext::lookup_deferred_operator_overload ( + HirId id, DeferredOpOverload *deferred) +{ + auto it = deferred_operator_overloads.find (id); + if (it == deferred_operator_overloads.end ()) + return false; + + *deferred = it->second; + return true; +} + +void +TypeCheckContext::iterate_deferred_operator_overloads ( + std::function cb) +{ + for (auto it = deferred_operator_overloads.begin (); + it != deferred_operator_overloads.end (); it++) + { + if (!cb (it->first, it->second)) + return; + } +} + void TypeCheckContext::insert_unconstrained_check_marker (HirId id, bool status) { @@ -574,10 +607,38 @@ TypeCheckContext::regions_from_generic_args (const HIR::GenericArgs &args) const return regions; } +bool +TypeCheckContext::compute_ambigious_op_overload (HirId id, + DeferredOpOverload &op) +{ + rust_debug ("attempting resolution of op overload: %s", + op.predicate.as_string ().c_str ()); + + TyTy::BaseType *lhs = nullptr; + bool ok = lookup_type (op.op.get_lvalue_mappings ().get_hirid (), &lhs); + rust_assert (ok); + + TyTy::BaseType *rhs = nullptr; + if (op.op.has_rvalue_mappings ()) + { + bool ok = lookup_type (op.op.get_rvalue_mappings ().get_hirid (), &rhs); + rust_assert (ok); + } + + TypeCheckExpr::ResolveOpOverload (op.lang_item_type, op.op, lhs, rhs, + op.specified_segment); + + return true; +} + void TypeCheckContext::compute_inference_variables (bool emit_error) { - // default inference variables if possible + iterate_deferred_operator_overloads ( + [&] (HirId id, DeferredOpOverload &op) mutable -> bool { + return compute_ambigious_op_overload (id, op); + }); + iterate ([&] (HirId id, TyTy::BaseType *ty) mutable -> bool { return compute_infer_var (id, ty, emit_error); }); diff --git a/gcc/testsuite/rust/compile/issue-3916.rs b/gcc/testsuite/rust/compile/issue-3916.rs new file mode 100644 index 00000000000..59b522b4ed5 --- /dev/null +++ b/gcc/testsuite/rust/compile/issue-3916.rs @@ -0,0 +1,36 @@ +#![feature(rustc_attrs)] + +#[lang = "sized"] +trait Sized {} + +#[lang = "add"] +trait Add { + type Output; + + fn add(self, rhs: Rhs) -> Self::Output; +} + +macro_rules! add_impl { + ($($t:ty)*) => ($( + impl Add for $t { + type Output = $t; + + #[inline] + #[rustc_inherit_overflow_checks] + fn add(self, other: $t) -> $t { self + other } + } + )*) +} + +add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 } + +pub fn test(len: usize) -> u64 { + let mut i = 0; + let mut out = 0; + if i + 3 < len { + out = 123; + } else { + out = 456; + } + out +}