]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: Add method selection to operator overloading
authorPhilip Herron <herron.philip@googlemail.com>
Sun, 25 Jun 2023 20:21:36 +0000 (21:21 +0100)
committerArthur Cohen <arthur.cohen@embecosm.com>
Tue, 16 Jan 2024 17:46:29 +0000 (18:46 +0100)
When we do operator overloading we can get multiple candidates and we need
to use the optional arguments to filter the candidates event further. In
the bug we had:

  <integer> + <usize>

Without the Add impl blocks for the primitive interger types we unify
against the rhs to figure out that the lhs should be a usize but when we
are using the impl blocks we need to use the rhs to ensure that its
possible to coerce the rhs to the expected fntype parameter to filter the
candidates.

Fixes #2304

gcc/rust/ChangeLog:

* typecheck/rust-autoderef.cc: use new selection filter
* typecheck/rust-hir-dot-operator.cc (MethodResolver::Select): new slection Filter
* typecheck/rust-hir-dot-operator.h: New select prototype
* typecheck/rust-hir-type-check-expr.cc: call select
* typecheck/rust-type-util.cc (try_coercion): new helper
* typecheck/rust-type-util.h (try_coercion): helper prototype

gcc/testsuite/ChangeLog:

* rust/compile/issue-2304.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/typecheck/rust-autoderef.cc
gcc/rust/typecheck/rust-hir-dot-operator.cc
gcc/rust/typecheck/rust-hir-dot-operator.h
gcc/rust/typecheck/rust-hir-type-check-expr.cc
gcc/rust/typecheck/rust-type-util.cc
gcc/rust/typecheck/rust-type-util.h
gcc/testsuite/rust/compile/issue-2304.rs [new file with mode: 0644]

index 8f5f6242aa2b887db0e2c3e28ea9eced1eee4624..9450cfaa068bae054bc8ff264031e19cdd2ce856 100644 (file)
@@ -170,18 +170,20 @@ resolve_operator_overload_fn (
        }
     }
 
-  bool have_implementation_for_lang_item = resolved_candidates.size () > 0;
+  auto selected_candidates
+    = MethodResolver::Select (resolved_candidates, lhs, {});
+  bool have_implementation_for_lang_item = selected_candidates.size () > 0;
   if (!have_implementation_for_lang_item)
     return false;
 
-  if (resolved_candidates.size () > 1)
+  if (selected_candidates.size () > 1)
     {
       // no need to error out as we are just trying to see if there is a fit
       return false;
     }
 
   // Get the adjusted self
-  MethodCandidate candidate = *resolved_candidates.begin ();
+  MethodCandidate candidate = *selected_candidates.begin ();
   Adjuster adj (lhs);
   TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
 
index 28bc0c0c9a5e848e3d3b0c1b71c21a95a4e4a291..af4b87ef2cc74ae7306afe6229bdf5cb4ddadcee 100644 (file)
@@ -41,6 +41,45 @@ MethodResolver::Probe (TyTy::BaseType *receiver,
   return resolver.result;
 }
 
+std::set<MethodCandidate>
+MethodResolver::Select (std::set<MethodCandidate> &candidates,
+                       TyTy::BaseType *receiver,
+                       std::vector<TyTy::BaseType *> arguments)
+{
+  std::set<MethodCandidate> selected;
+  for (auto &candidate : candidates)
+    {
+      TyTy::BaseType *candidate_type = candidate.candidate.ty;
+      rust_assert (candidate_type->get_kind () == TyTy::TypeKind::FNDEF);
+      TyTy::FnType &fn = *static_cast<TyTy::FnType *> (candidate_type);
+
+      // match the number of arguments
+      if (fn.num_params () != (arguments.size () + 1))
+       continue;
+
+      // match the arguments
+      bool failed = false;
+      for (size_t i = 0; i < arguments.size (); i++)
+       {
+         TyTy::BaseType *arg = arguments.at (i);
+         TyTy::BaseType *param = fn.get_params ().at (i + 1).second;
+         TyTy::BaseType *coerced
+           = try_coercion (0, TyTy::TyWithLocation (param),
+                           TyTy::TyWithLocation (arg), Location ());
+         if (coerced->get_kind () == TyTy::TypeKind::ERROR)
+           {
+             failed = true;
+             break;
+           }
+       }
+
+      if (!failed)
+       selected.insert (candidate);
+    }
+
+  return selected;
+}
+
 void
 MethodResolver::try_hook (const TyTy::BaseType &r)
 {
index aa31d7a118e926ae369ac84b3277965c5b342eea..804b18fbff24a20019f05c5a67b3a73c709ac116 100644 (file)
@@ -57,6 +57,10 @@ public:
   Probe (TyTy::BaseType *receiver, const HIR::PathIdentSegment &segment_name,
         bool autoderef_flag = false);
 
+  static std::set<MethodCandidate>
+  Select (std::set<MethodCandidate> &candidates, TyTy::BaseType *receiver,
+         std::vector<TyTy::BaseType *> arguments);
+
   static std::vector<predicate_candidate> get_predicate_items (
     const HIR::PathIdentSegment &segment_name, const TyTy::BaseType &receiver,
     const std::vector<TyTy::TypeBoundPredicate> &specified_bounds);
index 734c4b49a5f2b685a04db97de436e93097b73760..8affe983ba853d07887b9f0e6e6eb5e7ed65332f 100644 (file)
@@ -1619,11 +1619,17 @@ TypeCheckExpr::resolve_operator_overload (
        }
     }
 
-  bool have_implementation_for_lang_item = resolved_candidates.size () > 0;
+  std::vector<TyTy::BaseType *> select_args = {};
+  if (rhs != nullptr)
+    select_args = {rhs};
+  auto selected_candidates
+    = MethodResolver::Select (resolved_candidates, lhs, select_args);
+
+  bool have_implementation_for_lang_item = selected_candidates.size () > 0;
   if (!have_implementation_for_lang_item)
     return false;
 
-  if (resolved_candidates.size () > 1)
+  if (selected_candidates.size () > 1)
     {
       // mutliple candidates
       RichLocation r (expr.get_locus ());
@@ -1637,7 +1643,7 @@ TypeCheckExpr::resolve_operator_overload (
     }
 
   // Get the adjusted self
-  MethodCandidate candidate = *resolved_candidates.begin ();
+  MethodCandidate candidate = *selected_candidates.begin ();
   Adjuster adj (lhs);
   TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
 
index 8e8871cab4deae622a05c5b341fc0f324d12ffbe..561509cda005dcf619df66c5876b643660b46eb0 100644 (file)
@@ -231,6 +231,24 @@ coercion_site (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs,
   return coerced;
 }
 
+TyTy::BaseType *
+try_coercion (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs,
+             Location locus)
+{
+  TyTy::BaseType *expected = lhs.get_ty ();
+  TyTy::BaseType *expr = rhs.get_ty ();
+
+  rust_debug ("try_coercion_site id={%u} expected={%s} expr={%s}", id,
+             expected->debug_str ().c_str (), expr->debug_str ().c_str ());
+
+  auto result = TypeCoercionRules::TryCoerce (expr, expected, locus,
+                                             true /*allow-autodref*/);
+  if (result.is_error ())
+    return new TyTy::ErrorType (id);
+
+  return result.tyty;
+}
+
 TyTy::BaseType *
 cast_site (HirId id, TyTy::TyWithLocation from, TyTy::TyWithLocation to,
           Location cast_locus)
index 938f410296a6f98390d278c78173533facc63299..595388ef34f747d1be9a936184d32421b53fa48f 100644 (file)
@@ -45,6 +45,10 @@ TyTy::BaseType *
 coercion_site (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs,
               Location coercion_locus);
 
+TyTy::BaseType *
+try_coercion (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs,
+             Location coercion_locus);
+
 TyTy::BaseType *
 cast_site (HirId id, TyTy::TyWithLocation from, TyTy::TyWithLocation to,
           Location cast_locus);
diff --git a/gcc/testsuite/rust/compile/issue-2304.rs b/gcc/testsuite/rust/compile/issue-2304.rs
new file mode 100644 (file)
index 0000000..243cf10
--- /dev/null
@@ -0,0 +1,23 @@
+#[lang = "add"]
+pub trait Add<RHS = Self> {
+    type Output;
+
+    fn add(self, rhs: RHS) -> Self::Output;
+}
+macro_rules! add_impl {
+    ($($t:ty)*) => ($(
+        impl Add for $t {
+            type Output = $t;
+
+            fn add(self, other: $t) -> $t { self + other }
+        }
+    )*)
+}
+
+add_impl! { usize u8 u16 u32 u64  /*isize i8 i16 i32 i64*/  f32 f64 }
+
+pub fn test() {
+    let x: usize = 123;
+    let mut i = 0;
+    let _bug = i + x;
+}