]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: Improve operator overload check for recursive overload
authorPhilip Herron <herron.philip@googlemail.com>
Mon, 17 Apr 2023 21:17:37 +0000 (22:17 +0100)
committerArthur Cohen <arthur.cohen@embecosm.com>
Tue, 16 Jan 2024 17:34:15 +0000 (18:34 +0100)
This is a case in #2019 where you have the default Add impl for i32 for
example which has:

  impl Add for i32 {
    fn add(self, other:i32) { self + other }
  }

This function will do a check for operator overload which then is able to
find multiple candidates such as:

  impl<'a> Add<i32> for &'a i32 {
    type Output = <i32 as Add<i32>>::Output;

    fn add(self, other: i32) -> <i32 as Add<i32>>::Output {
        Add::add(*self, other)
    }
  }

This initial operator overload will resolve to this as it looks like a
valid candidate. This patch adds another level of checks to ensure the
candidate does not equal the current context DefId.

Addresses #2019

gcc/rust/ChangeLog:

* typecheck/rust-hir-type-check-expr.cc: update for op overload

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/typecheck/rust-hir-type-check-expr.cc

index 2eec75db3ca8e328a7c0b3a20dea20a4dcaa9171..5eddada3f6293f30b0929da0494fde09078254fe 100644 (file)
@@ -1589,18 +1589,43 @@ TypeCheckExpr::resolve_operator_overload (
   if (!lang_item_defined)
     return false;
 
+  // we might be in a static or const context and unknown is fine
+  TypeCheckContextItem current_context = TypeCheckContextItem::get_error ();
+  if (context->have_function_context ())
+    {
+      current_context = context->peek_context ();
+    }
+
   auto segment = HIR::PathIdentSegment (associated_item_name);
   auto candidates = MethodResolver::Probe (lhs, segment);
 
-  bool have_implementation_for_lang_item = candidates.size () > 0;
+  // remove any recursive candidates
+  std::set<MethodCandidate> resolved_candidates;
+  for (auto &c : candidates)
+    {
+      const TyTy::BaseType *candidate_type = c.candidate.ty;
+      rust_assert (candidate_type->get_kind () == TyTy::TypeKind::FNDEF);
+
+      const TyTy::FnType &fn
+       = *static_cast<const TyTy::FnType *> (candidate_type);
+
+      DefId current_fn_defid = current_context.get_defid ();
+      bool recursive_candidated = fn.get_id () == current_fn_defid;
+      if (!recursive_candidated)
+       {
+         resolved_candidates.insert (c);
+       }
+    }
+
+  bool have_implementation_for_lang_item = resolved_candidates.size () > 0;
   if (!have_implementation_for_lang_item)
     return false;
 
-  if (candidates.size () > 1)
+  if (resolved_candidates.size () > 1)
     {
       // mutliple candidates
       RichLocation r (expr.get_locus ());
-      for (auto &c : candidates)
+      for (auto &c : resolved_candidates)
        r.add_range (c.candidate.locus);
 
       rust_error_at (
@@ -1610,18 +1635,40 @@ TypeCheckExpr::resolve_operator_overload (
     }
 
   // Get the adjusted self
-  auto candidate = *candidates.begin ();
+  MethodCandidate candidate = *resolved_candidates.begin ();
   Adjuster adj (lhs);
   TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
 
-  // is this the case we are recursive
-  // handle the case where we are within the impl block for this lang_item
-  // otherwise we end up with a recursive operator overload such as the i32
-  // operator overload trait
-  TypeCheckContextItem fn_context = context->peek_context ();
-  if (fn_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
+  // store the adjustments for code-generation to know what to do
+  context->insert_autoderef_mappings (expr.get_lvalue_mappings ().get_hirid (),
+                                     std::move (candidate.adjustments));
+
+  // now its just like a method-call-expr
+  context->insert_receiver (expr.get_mappings ().get_hirid (), lhs);
+
+  PathProbeCandidate &resolved_candidate = candidate.candidate;
+  TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
+  NodeId resolved_node_id
+    = resolved_candidate.is_impl_candidate ()
+       ? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
+           .get_nodeid ()
+       : resolved_candidate.item.trait.item_ref->get_mappings ().get_nodeid ();
+
+  rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
+  TyTy::BaseType *lookup = lookup_tyty;
+  TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
+  rust_assert (fn->is_method ());
+
+  rust_debug ("is_impl_item_candidate: %s",
+             resolved_candidate.is_impl_candidate () ? "true" : "false");
+
+  // in the case where we resolve to a trait bound we have to be careful we are
+  // able to do so there is a case where we are currently resolving the deref
+  // operator overload function which is generic and this might resolve to the
+  // trait item of deref which is not valid as its just another recursive case
+  if (current_context.get_type () == TypeCheckContextItem::ItemType::IMPL_ITEM)
     {
-      auto &impl_item = fn_context.get_impl_item ();
+      auto &impl_item = current_context.get_impl_item ();
       HIR::ImplBlock *parent = impl_item.first;
       HIR::Function *fn = impl_item.second;
 
@@ -1655,26 +1702,7 @@ TypeCheckExpr::resolve_operator_overload (
        }
     }
 
-  // store the adjustments for code-generation to know what to do
-  context->insert_autoderef_mappings (expr.get_lvalue_mappings ().get_hirid (),
-                                     std::move (candidate.adjustments));
-
-  // now its just like a method-call-expr
-  context->insert_receiver (expr.get_mappings ().get_hirid (), lhs);
-
-  PathProbeCandidate &resolved_candidate = candidate.candidate;
-  TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
-  NodeId resolved_node_id
-    = resolved_candidate.is_impl_candidate ()
-       ? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
-           .get_nodeid ()
-       : resolved_candidate.item.trait.item_ref->get_mappings ().get_nodeid ();
-
-  rust_assert (lookup_tyty->get_kind () == TyTy::TypeKind::FNDEF);
-  TyTy::BaseType *lookup = lookup_tyty;
-  TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
-  rust_assert (fn->is_method ());
-
+  // we found a valid operator overload
   fn->prepare_higher_ranked_bounds ();
   rust_debug_loc (expr.get_locus (), "resolved operator overload to: {%u} {%s}",
                  candidate.candidate.ty->get_ref (),