]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: Add initial support for deffered operator overload resolution
authorPhilip Herron <herron.philip@googlemail.com>
Fri, 18 Jul 2025 14:46:59 +0000 (15:46 +0100)
committerArthur Cohen <arthur.cohen@embecosm.com>
Tue, 5 Aug 2025 14:36:56 +0000 (16:36 +0200)
In the test case:

  fn test (len: usize) -> u64 {
     let mut i = 0;
     let mut out = 0;
     if i + 3 < len {
        out = 123;
     }
     out
  }

The issue is to determine the correct type of 'i', out is simple because it hits a
coercion site in the resturn position for u64. But 'i + 3', 'i' is an integer infer
variable and the same for the literal '3'. So when it comes to resolving the type for
the Add expression we hit the resolve the operator overload code and because of this:

  macro_rules! add_impl {
      ($($t:ty)*) => ($(
          impl Add for $t {
              type Output = $t;

              #[inline]
              #[rustc_inherit_overflow_checks]
              fn add(self, other: $t) -> $t { self + other }
          }
      )*)
  }

  add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }

This means the resolution for 'i + 3' is ambigious because it could be any of these Add
implementations. But because we unify against the '< len' where len is defined as usize
later in the resolution we determine 'i' is actually a usize. Which means if we defer the
resolution of this operator overload in the ambigious case we can simply resolve it at the
end.

Fixes Rust-GCC#3916

gcc/rust/ChangeLog:

* hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): track the rhs
* hir/tree/rust-hir-expr.h: likewise
* hir/tree/rust-hir-path.h: get rid of old comments
* typecheck/rust-hir-trait-reference.cc (TraitReference::get_trait_substs): return
references instead of copy
* typecheck/rust-hir-trait-reference.h: update header
* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::ResolveOpOverload): write ambigious
operator overloads to a table and try to resolve it at the end
* typecheck/rust-hir-type-check-expr.h: new static helper
* typecheck/rust-hir-type-check.h (struct DeferredOpOverload): new model to defer resolution
* typecheck/rust-typecheck-context.cc (TypeCheckContext::lookup_operator_overload): new
(TypeCheckContext::compute_ambigious_op_overload): likewise
(TypeCheckContext::compute_inference_variables): likewise

gcc/testsuite/ChangeLog:

* rust/compile/issue-3916.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/hir/tree/rust-hir-expr.cc
gcc/rust/hir/tree/rust-hir-expr.h
gcc/rust/hir/tree/rust-hir-path.h
gcc/rust/typecheck/rust-hir-trait-reference.cc
gcc/rust/typecheck/rust-hir-trait-reference.h
gcc/rust/typecheck/rust-hir-type-check-expr.cc
gcc/rust/typecheck/rust-hir-type-check-expr.h
gcc/rust/typecheck/rust-hir-type-check.h
gcc/rust/typecheck/rust-typecheck-context.cc
gcc/testsuite/rust/compile/issue-3916.rs [new file with mode: 0644]

index 93dcec2c8d7945045a90bb9174ecd033b233fed5..038bfc77f94a1cfe4f614fef8385ba6e9818a0e0 100644 (file)
@@ -17,6 +17,7 @@
 // <http://www.gnu.org/licenses/>.
 
 #include "rust-hir-expr.h"
+#include "rust-hir-map.h"
 #include "rust-operators.h"
 #include "rust-hir-stmt.h"
 
@@ -1321,37 +1322,40 @@ AsyncBlockExpr::operator= (AsyncBlockExpr const &other)
 OperatorExprMeta::OperatorExprMeta (HIR::CompoundAssignmentExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ArithmeticOrLogicalExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::NegationExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
+    rvalue_mappings (Analysis::NodeMapping::get_error ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::DereferenceExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
+    rvalue_mappings (Analysis::NodeMapping::get_error ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_array_expr ().get_mappings ()),
+    rvalue_mappings (expr.get_index_expr ().get_mappings ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ())
 {}
 
 InlineAsmOperand::In::In (
index fcb4744fef4cf6adfa7a6e62f73b2683857d6525..028455b987023b989beece76d63308e362e361c5 100644 (file)
@@ -27,6 +27,7 @@
 #include "rust-hir-attrs.h"
 #include "rust-expr.h"
 #include "rust-hir-map.h"
+#include "rust-mapping-common.h"
 
 namespace Rust {
 namespace HIR {
@@ -2892,6 +2893,22 @@ public:
 
   OperatorExprMeta (HIR::ComparisonExpr &expr);
 
+  OperatorExprMeta (const OperatorExprMeta &other)
+    : node_mappings (other.node_mappings),
+      lvalue_mappings (other.lvalue_mappings),
+      rvalue_mappings (other.rvalue_mappings), locus (other.locus)
+  {}
+
+  OperatorExprMeta &operator= (const OperatorExprMeta &other)
+  {
+    node_mappings = other.node_mappings;
+    lvalue_mappings = other.lvalue_mappings;
+    rvalue_mappings = other.rvalue_mappings;
+    locus = other.locus;
+
+    return *this;
+  }
+
   const Analysis::NodeMapping &get_mappings () const { return node_mappings; }
 
   const Analysis::NodeMapping &get_lvalue_mappings () const
@@ -2899,11 +2916,22 @@ public:
     return lvalue_mappings;
   }
 
+  const Analysis::NodeMapping &get_rvalue_mappings () const
+  {
+    return rvalue_mappings;
+  }
+
+  bool has_rvalue_mappings () const
+  {
+    return rvalue_mappings.get_hirid () != UNKNOWN_HIRID;
+  }
+
   location_t get_locus () const { return locus; }
 
 private:
-  const Analysis::NodeMapping node_mappings;
-  const Analysis::NodeMapping lvalue_mappings;
+  Analysis::NodeMapping node_mappings;
+  Analysis::NodeMapping lvalue_mappings;
+  Analysis::NodeMapping rvalue_mappings;
   location_t locus;
 };
 
index 3ce2662c8024e1e836584a6556e66b9f447c87cb..5f88c6827bb1d178364c1237e5b78c2b3cf86bd8 100644 (file)
@@ -41,11 +41,15 @@ public:
     : segment_name (std::move (segment_name))
   {}
 
-  /* TODO: insert check in constructor for this? Or is this a semantic error
-   * best handled then? */
+  PathIdentSegment (const PathIdentSegment &other)
+    : segment_name (other.segment_name)
+  {}
 
-  /* TODO: does this require visitor? pretty sure this isn't polymorphic, but
-   * not entirely sure */
+  PathIdentSegment &operator= (PathIdentSegment const &other)
+  {
+    segment_name = other.segment_name;
+    return *this;
+  }
 
   // Creates an error PathIdentSegment.
   static PathIdentSegment create_error () { return PathIdentSegment (""); }
index 88e270d510d2f7013377435693a959aef8306c83..74856f098fa0e02368ea695b42a0ca6460b2e8aa 100644 (file)
@@ -432,7 +432,13 @@ TraitReference::trait_has_generics () const
   return !trait_substs.empty ();
 }
 
-std::vector<TyTy::SubstitutionParamMapping>
+std::vector<TyTy::SubstitutionParamMapping> &
+TraitReference::get_trait_substs ()
+{
+  return trait_substs;
+}
+
+const std::vector<TyTy::SubstitutionParamMapping> &
 TraitReference::get_trait_substs () const
 {
   return trait_substs;
index 8b1ac7daf7f1c91a24ae87fb2f7473e3a4c79b31..473513ea75ff1eb68fc96481da7b7108cd64385e 100644 (file)
@@ -224,7 +224,9 @@ public:
 
   bool trait_has_generics () const;
 
-  std::vector<TyTy::SubstitutionParamMapping> get_trait_substs () const;
+  std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs ();
+
+  const std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs () const;
 
   bool satisfies_bound (const TraitReference &reference) const;
 
index c1404561f4d084090cc5e61847c846ebf9b01ba3..5db0e5690c97d6c8332942cdc6d9a1219b35b2da 100644 (file)
@@ -17,6 +17,7 @@
 // <http://www.gnu.org/licenses/>.
 
 #include "optional.h"
+#include "rust-common.h"
 #include "rust-hir-expr.h"
 #include "rust-system.h"
 #include "rust-tyty-call.h"
@@ -59,6 +60,19 @@ TypeCheckExpr::Resolve (HIR::Expr &expr)
   return resolver.infered;
 }
 
+TyTy::BaseType *
+TypeCheckExpr::ResolveOpOverload (LangItem::Kind lang_item_type,
+                                 HIR::OperatorExprMeta expr,
+                                 TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+                                 HIR::PathIdentSegment specified_segment)
+{
+  TypeCheckExpr resolver;
+
+  resolver.resolve_operator_overload (lang_item_type, expr, lhs, rhs,
+                                     specified_segment);
+  return resolver.infered;
+}
+
 void
 TypeCheckExpr::visit (HIR::TupleIndexExpr &expr)
 {
@@ -1885,7 +1899,16 @@ TypeCheckExpr::resolve_operator_overload (
   // probe for the lang-item
   if (!lang_item_defined)
     return false;
+
   DefId &respective_lang_item_id = lang_item_defined.value ();
+  auto def_lookup = mappings.lookup_defid (respective_lang_item_id);
+  rust_assert (def_lookup.has_value ());
+
+  HIR::Item *def_item = def_lookup.value ();
+  rust_assert (def_item->get_item_kind () == HIR::Item::ItemKind::Trait);
+  HIR::Trait &trait = *static_cast<HIR::Trait *> (def_item);
+  TraitReference *defid_trait_reference = TraitResolver::Resolve (trait);
+  rust_assert (!defid_trait_reference->is_error ());
 
   // we might be in a static or const context and unknown is fine
   TypeCheckContextItem current_context = TypeCheckContextItem::get_error ();
@@ -1929,15 +1952,49 @@ TypeCheckExpr::resolve_operator_overload (
 
   if (selected_candidates.size () > 1)
     {
-      // mutliple candidates
-      rich_location r (line_table, expr.get_locus ());
-      for (auto &c : resolved_candidates)
-       r.add_range (c.candidate.locus);
+      auto infer
+       = TyTy::TyVar::get_implicit_infer_var (expr.get_locus ()).get_tyty ();
+      auto trait_subst = defid_trait_reference->get_trait_substs ();
+      rust_assert (trait_subst.size () > 0);
 
-      rust_error_at (
-       r, "multiple candidates found for possible operator overload");
+      TyTy::TypeBoundPredicate pred (respective_lang_item_id, trait_subst,
+                                    BoundPolarity::RegularBound,
+                                    expr.get_locus ());
 
-      return false;
+      std::vector<TyTy::SubstitutionArg> mappings;
+      auto &self_param_mapping = trait_subst[0];
+      mappings.push_back (TyTy::SubstitutionArg (&self_param_mapping, lhs));
+
+      if (rhs != nullptr)
+       {
+         rust_assert (trait_subst.size () == 2);
+         auto &rhs_param_mapping = trait_subst[1];
+         mappings.push_back (TyTy::SubstitutionArg (&rhs_param_mapping, lhs));
+       }
+
+      std::map<std::string, TyTy::BaseType *> binding_args;
+      binding_args["Output"] = infer;
+
+      TyTy::SubstitutionArgumentMappings arg_mappings (mappings, binding_args,
+                                                      TyTy::RegionParamList (
+                                                        trait_subst.size ()),
+                                                      expr.get_locus ());
+      pred.apply_argument_mappings (arg_mappings, false);
+
+      infer->inherit_bounds ({pred});
+      DeferredOpOverload defer (expr.get_mappings ().get_hirid (),
+                               lang_item_type, specified_segment, pred, expr);
+      context->insert_deferred_operator_overload (std::move (defer));
+
+      if (rhs != nullptr)
+       lhs = unify_site (expr.get_mappings ().get_hirid (),
+                         TyTy::TyWithLocation (lhs),
+                         TyTy::TyWithLocation (rhs), expr.get_locus ());
+
+      infered = unify_site (expr.get_mappings ().get_hirid (),
+                           TyTy::TyWithLocation (lhs),
+                           TyTy::TyWithLocation (infer), expr.get_locus ());
+      return true;
     }
 
   // Get the adjusted self
index 5311974368532daf2f6d4a329533b4caa53beee2..48f28c7007959bd6e3ff0a61c7a6cc7c027e57e9 100644 (file)
@@ -31,6 +31,11 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor
 public:
   static TyTy::BaseType *Resolve (HIR::Expr &expr);
 
+  static TyTy::BaseType *
+  ResolveOpOverload (LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
+                    TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+                    HIR::PathIdentSegment specified_segment);
+
   void visit (HIR::TupleIndexExpr &expr) override;
   void visit (HIR::TupleExpr &expr) override;
   void visit (HIR::ReturnExpr &expr) override;
index 356c55803ed62e011e0d2b9642a417f7318ef032..80e403448359127203eb50f64f088db3f44f13a8 100644 (file)
@@ -20,6 +20,7 @@
 #define RUST_HIR_TYPE_CHECK
 
 #include "rust-hir-map.h"
+#include "rust-mapping-common.h"
 #include "rust-tyty.h"
 #include "rust-hir-trait-reference.h"
 #include "rust-stacked-contexts.h"
@@ -157,6 +158,39 @@ public:
   WARN_UNUSED_RESULT Lifetime next () { return Lifetime (interner_index++); }
 };
 
+struct DeferredOpOverload
+{
+  HirId expr_id;
+  LangItem::Kind lang_item_type;
+  HIR::PathIdentSegment specified_segment;
+  TyTy::TypeBoundPredicate predicate;
+  HIR::OperatorExprMeta op;
+
+  DeferredOpOverload (HirId expr_id, LangItem::Kind lang_item_type,
+                     HIR::PathIdentSegment specified_segment,
+                     TyTy::TypeBoundPredicate &predicate,
+                     HIR::OperatorExprMeta op)
+    : expr_id (expr_id), lang_item_type (lang_item_type),
+      specified_segment (specified_segment), predicate (predicate), op (op)
+  {}
+
+  DeferredOpOverload (const struct DeferredOpOverload &other)
+    : expr_id (other.expr_id), lang_item_type (other.lang_item_type),
+      specified_segment (other.specified_segment), predicate (other.predicate),
+      op (other.op)
+  {}
+
+  DeferredOpOverload &operator= (struct DeferredOpOverload const &other)
+  {
+    expr_id = other.expr_id;
+    lang_item_type = other.lang_item_type;
+    specified_segment = other.specified_segment;
+    op = other.op;
+
+    return *this;
+  }
+};
+
 class TypeCheckContext
 {
 public:
@@ -237,6 +271,13 @@ public:
   void insert_operator_overload (HirId id, TyTy::FnType *call_site);
   bool lookup_operator_overload (HirId id, TyTy::FnType **call);
 
+  void insert_deferred_operator_overload (DeferredOpOverload deferred);
+  bool lookup_deferred_operator_overload (HirId id,
+                                         DeferredOpOverload *deferred);
+
+  void iterate_deferred_operator_overloads (
+    std::function<bool (HirId, DeferredOpOverload &)> cb);
+
   void insert_unconstrained_check_marker (HirId id, bool status);
   bool have_checked_for_unconstrained (HirId id, bool *result);
 
@@ -271,6 +312,7 @@ private:
   TypeCheckContext ();
 
   bool compute_infer_var (HirId id, TyTy::BaseType *ty, bool emit_error);
+  bool compute_ambigious_op_overload (HirId id, DeferredOpOverload &op);
 
   std::map<NodeId, HirId> node_id_refs;
   std::map<HirId, TyTy::BaseType *> resolved;
@@ -308,6 +350,9 @@ private:
   std::set<HirId> querys_in_progress;
   std::set<DefId> trait_queries_in_progress;
 
+  // deferred operator overload
+  std::map<HirId, DeferredOpOverload> deferred_operator_overloads;
+
   // variance analysis
   TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx;
 
index 7b3584823e44cead719e6cf05fe3a6aa60dc74f0..83b17612d5e36d5aef7af775d3c0f05bca2df7c4 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "rust-hir-type-check.h"
 #include "rust-type-util.h"
+#include "rust-hir-type-check-expr.h"
 
 namespace Rust {
 namespace Resolver {
@@ -408,6 +409,38 @@ TypeCheckContext::lookup_operator_overload (HirId id, TyTy::FnType **call)
   return true;
 }
 
+void
+TypeCheckContext::insert_deferred_operator_overload (
+  DeferredOpOverload deferred)
+{
+  HirId expr_id = deferred.expr_id;
+  deferred_operator_overloads.emplace (std::make_pair (expr_id, deferred));
+}
+
+bool
+TypeCheckContext::lookup_deferred_operator_overload (
+  HirId id, DeferredOpOverload *deferred)
+{
+  auto it = deferred_operator_overloads.find (id);
+  if (it == deferred_operator_overloads.end ())
+    return false;
+
+  *deferred = it->second;
+  return true;
+}
+
+void
+TypeCheckContext::iterate_deferred_operator_overloads (
+  std::function<bool (HirId, DeferredOpOverload &)> cb)
+{
+  for (auto it = deferred_operator_overloads.begin ();
+       it != deferred_operator_overloads.end (); it++)
+    {
+      if (!cb (it->first, it->second))
+       return;
+    }
+}
+
 void
 TypeCheckContext::insert_unconstrained_check_marker (HirId id, bool status)
 {
@@ -574,10 +607,38 @@ TypeCheckContext::regions_from_generic_args (const HIR::GenericArgs &args) const
   return regions;
 }
 
+bool
+TypeCheckContext::compute_ambigious_op_overload (HirId id,
+                                                DeferredOpOverload &op)
+{
+  rust_debug ("attempting resolution of op overload: %s",
+             op.predicate.as_string ().c_str ());
+
+  TyTy::BaseType *lhs = nullptr;
+  bool ok = lookup_type (op.op.get_lvalue_mappings ().get_hirid (), &lhs);
+  rust_assert (ok);
+
+  TyTy::BaseType *rhs = nullptr;
+  if (op.op.has_rvalue_mappings ())
+    {
+      bool ok = lookup_type (op.op.get_rvalue_mappings ().get_hirid (), &rhs);
+      rust_assert (ok);
+    }
+
+  TypeCheckExpr::ResolveOpOverload (op.lang_item_type, op.op, lhs, rhs,
+                                   op.specified_segment);
+
+  return true;
+}
+
 void
 TypeCheckContext::compute_inference_variables (bool emit_error)
 {
-  // default inference variables if possible
+  iterate_deferred_operator_overloads (
+    [&] (HirId id, DeferredOpOverload &op) mutable -> bool {
+      return compute_ambigious_op_overload (id, op);
+    });
+
   iterate ([&] (HirId id, TyTy::BaseType *ty) mutable -> bool {
     return compute_infer_var (id, ty, emit_error);
   });
diff --git a/gcc/testsuite/rust/compile/issue-3916.rs b/gcc/testsuite/rust/compile/issue-3916.rs
new file mode 100644 (file)
index 0000000..59b522b
--- /dev/null
@@ -0,0 +1,36 @@
+#![feature(rustc_attrs)]
+
+#[lang = "sized"]
+trait Sized {}
+
+#[lang = "add"]
+trait Add<Rhs = Self> {
+    type Output;
+
+    fn add(self, rhs: Rhs) -> Self::Output;
+}
+
+macro_rules! add_impl {
+    ($($t:ty)*) => ($(
+        impl Add for $t {
+            type Output = $t;
+
+            #[inline]
+            #[rustc_inherit_overflow_checks]
+            fn add(self, other: $t) -> $t { self + other }
+        }
+    )*)
+}
+
+add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
+
+pub fn test(len: usize) -> u64 {
+    let mut i = 0;
+    let mut out = 0;
+    if i + 3 < len {
+        out = 123;
+    } else {
+        out = 456;
+    }
+    out
+}