]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
Support associated type bound arguments
authorPhilip Herron <herron.philip@googlemail.com>
Sat, 7 Jan 2023 14:41:12 +0000 (14:41 +0000)
committerPhilip Herron <herron.philip@googlemail.com>
Tue, 10 Jan 2023 22:12:16 +0000 (22:12 +0000)
This patch adds support for the GenercArgsBinding type, where you can
specify the associated types of a trait bound using `<Foo=i32>` style
syntax. Note that the type-resolution relys on the i32 impl for Add
as type resolution will resolve the `a+a` to the core::ops::Add method
so code generation will require this to exist.

I have ameded testsuite/rust/compile/bounds.rs as this code is wrongly
creating an HIR::GenericArgs with a trait-object type and causing issues.
the parsing is still correct but we dont have the mechanism to represent
this in AST and HIR properly. I think we will need a new HIR::GenericArgs
AssociatedTypeBindingBound or something similar. We are still lacking
bounds checking during are type coercions and unifications so running this
example using an f32 will wrongly pass type checking, this will need
addressed next.

Fixes #1720

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/ChangeLog:

* hir/tree/rust-hir-path.h:
* typecheck/rust-hir-path-probe.h:
* typecheck/rust-hir-trait-resolve.cc:
* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit):
* typecheck/rust-tyty-bounds.cc (TypeCheckBase::get_predicate_from_bound):
(TypeBoundPredicate::TypeBoundPredicate):
(TypeBoundPredicate::operator=):
(TypeBoundPredicate::apply_generic_arguments):
(TypeBoundPredicateItem::get_tyty_for_receiver):
(TypeBoundPredicate::get_num_associated_bindings):
(TypeBoundPredicate::lookup_associated_type):
(TypeBoundPredicate::get_associated_type_items):
* typecheck/rust-tyty.cc (SubstitutionRef::get_mappings_from_generic_args):
(SubstitutionRef::infer_substitions):
(SubstitutionRef::solve_missing_mappings_from_this):
* typecheck/rust-tyty.h:

gcc/testsuite/ChangeLog:

* rust/compile/bounds.rs: change to use -fsyntax-only
* rust/execute/torture/issue-1720.rs: New test.

gcc/rust/hir/tree/rust-hir-path.h
gcc/rust/typecheck/rust-hir-path-probe.h
gcc/rust/typecheck/rust-hir-trait-resolve.cc
gcc/rust/typecheck/rust-hir-type-check-expr.cc
gcc/rust/typecheck/rust-tyty-bounds.cc
gcc/rust/typecheck/rust-tyty.cc
gcc/rust/typecheck/rust-tyty.h
gcc/testsuite/rust/compile/bounds.rs
gcc/testsuite/rust/execute/torture/issue-1720.rs [new file with mode: 0644]

index 8cc16f367cf38e07f916f01c006b23eedf0f984c..88eb7a7addd55d1c25705383c6db6f6d050a61f4 100644 (file)
@@ -105,9 +105,11 @@ public:
 
   std::string as_string () const;
 
-  Identifier get_identifier () const { return identifier; }
+  Identifier &get_identifier () { return identifier; }
+  const Identifier &get_identifier () const { return identifier; }
 
   std::unique_ptr<Type> &get_type () { return type; }
+  const std::unique_ptr<Type> &get_type () const { return type; }
 
   Location get_locus () const { return locus; }
 };
index c8f207cf008820284b3d9be1d7a8a8f25f42431e..ac7d4f5f98b34abaf6955d306331a2c33b73a1ed 100644 (file)
@@ -362,7 +362,8 @@ protected:
        mappings.push_back (TyTy::SubstitutionArg (param, receiver->clone ()));
 
        Location locus; // FIXME
-       TyTy::SubstitutionArgumentMappings args (std::move (mappings), locus);
+       TyTy::SubstitutionArgumentMappings args (std::move (mappings), {},
+                                                locus);
        trait_item_tyty = SubstMapperInternal::Resolve (trait_item_tyty, args);
       }
 
index 6bd3cc1a9e948f59767f95135bf4a23608c6124a..d968c2078f4da92453b56480bec2e081a4efe848 100644 (file)
@@ -441,8 +441,8 @@ AssociatedImplTrait::setup_associated_types (
        param_mappings[p.get_symbol ()] = a.get_tyty ()->get_ref ();
       };
 
-  TyTy::SubstitutionArgumentMappings infer_arguments (std::move (args), locus,
-                                                     param_subst_cb);
+  TyTy::SubstitutionArgumentMappings infer_arguments (std::move (args), {},
+                                                     locus, param_subst_cb);
   TyTy::BaseType *impl_self_infer
     = (associated_self->needs_generic_substitutions ())
        ? SubstMapperInternal::Resolve (associated_self, infer_arguments)
@@ -489,8 +489,9 @@ AssociatedImplTrait::setup_associated_types (
       hrtb_bound_arguments.push_back (r);
     }
 
-  rust_assert (impl_trait_predicate_args.size ()
-              == hrtb_bound_arguments.size ());
+  if (impl_trait_predicate_args.size () != hrtb_bound_arguments.size ())
+    return;
+
   for (size_t i = 0; i < impl_trait_predicate_args.size (); i++)
     {
       TyTy::BaseType *a = impl_trait_predicate_args.at (i);
@@ -521,7 +522,7 @@ AssociatedImplTrait::setup_associated_types (
     }
 
   TyTy::SubstitutionArgumentMappings associated_type_args (
-    std::move (associated_arguments), locus);
+    std::move (associated_arguments), {}, locus);
 
   ImplTypeIterator iter (*impl, [&] (HIR::TypeAlias &type) {
     TraitItemReference *resolved_trait_item = nullptr;
index 3ee59e6ca7c6f767d48574e64e3f413fcdb3bbbc..9bd4d09981402e74ce71c9d51f3387d063e4aa28 100644 (file)
@@ -618,7 +618,8 @@ TypeCheckExpr::visit (HIR::RangeFromToExpr &expr)
   const TyTy::SubstitutionParamMapping *param_ref = &adt->get_substs ().at (0);
   subst_mappings.push_back (TyTy::SubstitutionArg (param_ref, unified));
 
-  TyTy::SubstitutionArgumentMappings subst (subst_mappings, expr.get_locus ());
+  TyTy::SubstitutionArgumentMappings subst (subst_mappings, {},
+                                           expr.get_locus ());
   infered = SubstMapperInternal::Resolve (adt, subst);
 }
 
@@ -664,7 +665,8 @@ TypeCheckExpr::visit (HIR::RangeFromExpr &expr)
   const TyTy::SubstitutionParamMapping *param_ref = &adt->get_substs ().at (0);
   subst_mappings.push_back (TyTy::SubstitutionArg (param_ref, from_ty));
 
-  TyTy::SubstitutionArgumentMappings subst (subst_mappings, expr.get_locus ());
+  TyTy::SubstitutionArgumentMappings subst (subst_mappings, {},
+                                           expr.get_locus ());
   infered = SubstMapperInternal::Resolve (adt, subst);
 }
 
@@ -709,7 +711,8 @@ TypeCheckExpr::visit (HIR::RangeToExpr &expr)
   const TyTy::SubstitutionParamMapping *param_ref = &adt->get_substs ().at (0);
   subst_mappings.push_back (TyTy::SubstitutionArg (param_ref, from_ty));
 
-  TyTy::SubstitutionArgumentMappings subst (subst_mappings, expr.get_locus ());
+  TyTy::SubstitutionArgumentMappings subst (subst_mappings, {},
+                                           expr.get_locus ());
   infered = SubstMapperInternal::Resolve (adt, subst);
 }
 
@@ -792,7 +795,8 @@ TypeCheckExpr::visit (HIR::RangeFromToInclExpr &expr)
   const TyTy::SubstitutionParamMapping *param_ref = &adt->get_substs ().at (0);
   subst_mappings.push_back (TyTy::SubstitutionArg (param_ref, unified));
 
-  TyTy::SubstitutionArgumentMappings subst (subst_mappings, expr.get_locus ());
+  TyTy::SubstitutionArgumentMappings subst (subst_mappings, {},
+                                           expr.get_locus ());
   infered = SubstMapperInternal::Resolve (adt, subst);
 }
 
index 53eccb79d93b9ba3b4093154c307f179bb283a42..e5057c8e3c0cfd73d160d47676360c5605429f4c 100644 (file)
@@ -145,11 +145,10 @@ TypeCheckBase::get_predicate_from_bound (HIR::TypePath &type_path)
       break;
     }
 
-  // FIXME
-  // I think this should really be just be if the !args.is_empty() because
-  // someone might wrongly apply generic arguments where they should not and
-  // they will be missing error diagnostics
-  if (predicate.requires_generic_args ())
+  // we try to apply generic arguments when they are non empty and or when the
+  // predicate requires them so that we get the relevant Foo expects x number
+  // arguments but got zero see test case rust/compile/traits12.rs
+  if (!args.is_empty () || predicate.requires_generic_args ())
     {
       // this is applying generic arguments to a trait reference
       predicate.apply_generic_arguments (&args);
@@ -222,7 +221,7 @@ TypeBoundPredicate::TypeBoundPredicate (const TypeBoundPredicate &other)
     }
 
   used_arguments
-    = SubstitutionArgumentMappings (copied_arg_mappings,
+    = SubstitutionArgumentMappings (copied_arg_mappings, {},
                                    other.used_arguments.get_locus ());
 }
 
@@ -258,7 +257,7 @@ TypeBoundPredicate::operator= (const TypeBoundPredicate &other)
     }
 
   used_arguments
-    = SubstitutionArgumentMappings (copied_arg_mappings,
+    = SubstitutionArgumentMappings (copied_arg_mappings, {},
                                    other.used_arguments.get_locus ());
 
   return *this;
@@ -331,6 +330,19 @@ TypeBoundPredicate::apply_generic_arguments (HIR::GenericArgs *generic_args)
       if (ok && arg.get_tyty () != nullptr)
        sub.fill_param_ty (subst_mappings, subst_mappings.get_locus ());
     }
+
+  // associated argument mappings
+  for (auto &it : subst_mappings.get_binding_args ())
+    {
+      std::string identifier = it.first;
+      TyTy::BaseType *type = it.second;
+
+      TypeBoundPredicateItem item = lookup_associated_item (identifier);
+      rust_assert (!item.is_error ());
+
+      const auto item_ref = item.get_raw_item ();
+      item_ref->associated_type_set (type);
+    }
 }
 
 bool
@@ -389,7 +401,8 @@ TypeBoundPredicateItem::get_tyty_for_receiver (const TyTy::BaseType *receiver)
       adjusted_mappings.push_back (std::move (arg));
     }
 
-  SubstitutionArgumentMappings adjusted (adjusted_mappings, gargs.get_locus (),
+  SubstitutionArgumentMappings adjusted (adjusted_mappings, {},
+                                        gargs.get_locus (),
                                         gargs.get_subst_cb (),
                                         true /* trait-mode-flag */);
   return Resolver::SubstMapperInternal::Resolve (trait_item_tyty, adjusted);
@@ -421,6 +434,19 @@ TypeBoundPredicate::handle_substitions (
       p->set_ty_ref (s->get_ty_ref ());
     }
 
+  // associated argument mappings
+  for (auto &it : subst_mappings.get_binding_args ())
+    {
+      std::string identifier = it.first;
+      TyTy::BaseType *type = it.second;
+
+      TypeBoundPredicateItem item = lookup_associated_item (identifier);
+      rust_assert (!item.is_error ());
+
+      const auto item_ref = item.get_raw_item ();
+      item_ref->associated_type_set (type);
+    }
+
   // FIXME more error handling at some point
   // used_arguments = subst_mappings;
   // error_flag |= used_arguments.is_error ();
@@ -440,6 +466,13 @@ TypeBoundPredicate::requires_generic_args () const
 bool
 TypeBoundPredicate::contains_associated_types () const
 {
+  return get_num_associated_bindings () > 0;
+}
+
+size_t
+TypeBoundPredicate::get_num_associated_bindings () const
+{
+  size_t count = 0;
   auto trait_ref = get ();
   for (const auto &trait_item : trait_ref->get_trait_items ())
     {
@@ -447,9 +480,45 @@ TypeBoundPredicate::contains_associated_types () const
        = trait_item.get_trait_item_type ()
          == Resolver::TraitItemReference::TraitItemType::TYPE;
       if (is_associated_type)
-       return true;
+       count++;
+    }
+  return count;
+}
+
+TypeBoundPredicateItem
+TypeBoundPredicate::lookup_associated_type (const std::string &search)
+{
+  TypeBoundPredicateItem item = lookup_associated_item (search);
+
+  // only need to check that it is infact an associated type because other wise
+  // if it was not found it will just be an error node anyway
+  if (!item.is_error ())
+    {
+      const auto raw = item.get_raw_item ();
+      if (raw->get_trait_item_type ()
+         != Resolver::TraitItemReference::TraitItemType::TYPE)
+       return TypeBoundPredicateItem::error ();
+    }
+  return item;
+}
+
+std::vector<TypeBoundPredicateItem>
+TypeBoundPredicate::get_associated_type_items ()
+{
+  std::vector<TypeBoundPredicateItem> items;
+  auto trait_ref = get ();
+  for (const auto &trait_item : trait_ref->get_trait_items ())
+    {
+      bool is_associated_type
+       = trait_item.get_trait_item_type ()
+         == Resolver::TraitItemReference::TraitItemType::TYPE;
+      if (is_associated_type)
+       {
+         TypeBoundPredicateItem item (this, &trait_item);
+         items.push_back (std::move (item));
+       }
     }
-  return false;
+  return items;
 }
 
 // trait item reference
index 86f40af0fbe56472964d92e641826ef752e2cb2b..9358919ffaf4555d951f730272d1099e64017dec 100644 (file)
@@ -674,14 +674,58 @@ SubstitutionParamMapping::override_context ()
 SubstitutionArgumentMappings
 SubstitutionRef::get_mappings_from_generic_args (HIR::GenericArgs &args)
 {
+  std::map<std::string, BaseType *> binding_arguments;
   if (args.get_binding_args ().size () > 0)
     {
-      RichLocation r (args.get_locus ());
-      for (auto &binding : args.get_binding_args ())
-       r.add_range (binding.get_locus ());
+      if (supports_associated_bindings ())
+       {
+         if (args.get_binding_args ().size () > get_num_associated_bindings ())
+           {
+             RichLocation r (args.get_locus ());
+
+             rust_error_at (r,
+                            "generic item takes at most %lu type binding "
+                            "arguments but %lu were supplied",
+                            (unsigned long) get_num_associated_bindings (),
+                            (unsigned long) args.get_binding_args ().size ());
+             return SubstitutionArgumentMappings::error ();
+           }
 
-      rust_error_at (r, "associated type bindings are not allowed here");
-      return SubstitutionArgumentMappings::error ();
+         for (auto &binding : args.get_binding_args ())
+           {
+             BaseType *resolved
+               = Resolver::TypeCheckType::Resolve (binding.get_type ().get ());
+             if (resolved == nullptr
+                 || resolved->get_kind () == TyTy::TypeKind::ERROR)
+               {
+                 rust_error_at (binding.get_locus (),
+                                "failed to resolve type arguments");
+                 return SubstitutionArgumentMappings::error ();
+               }
+
+             // resolve to relevant binding
+             auto binding_item
+               = lookup_associated_type (binding.get_identifier ());
+             if (binding_item.is_error ())
+               {
+                 rust_error_at (binding.get_locus (),
+                                "unknown associated type binding: %s",
+                                binding.get_identifier ().c_str ());
+                 return SubstitutionArgumentMappings::error ();
+               }
+
+             binding_arguments[binding.get_identifier ()] = resolved;
+           }
+       }
+      else
+       {
+         RichLocation r (args.get_locus ());
+         for (auto &binding : args.get_binding_args ())
+           r.add_range (binding.get_locus ());
+
+         rust_error_at (r, "associated type bindings are not allowed here");
+         return SubstitutionArgumentMappings::error ();
+       }
     }
 
   // for inherited arguments
@@ -745,6 +789,7 @@ SubstitutionRef::get_mappings_from_generic_args (HIR::GenericArgs &args)
          if (resolved->contains_type_parameters ())
            {
              SubstitutionArgumentMappings intermediate (mappings,
+                                                        binding_arguments,
                                                         args.get_locus ());
              resolved = Resolver::SubstMapperInternal::Resolve (resolved,
                                                                 intermediate);
@@ -758,7 +803,8 @@ SubstitutionRef::get_mappings_from_generic_args (HIR::GenericArgs &args)
        }
     }
 
-  return SubstitutionArgumentMappings (mappings, args.get_locus ());
+  return SubstitutionArgumentMappings (mappings, binding_arguments,
+                                      args.get_locus ());
 }
 
 BaseType *
@@ -791,7 +837,13 @@ SubstitutionRef::infer_substitions (Location locus)
        }
     }
 
-  SubstitutionArgumentMappings infer_arguments (std::move (args), locus);
+  // FIXME do we need to add inference variables to all the possible bindings?
+  // it might just lead to inference variable hell not 100% sure if rustc does
+  // this i think the language might needs this to be explicitly set
+
+  SubstitutionArgumentMappings infer_arguments (std::move (args),
+                                               {} /* binding_arguments */,
+                                               locus);
   return handle_substitions (std::move (infer_arguments));
 }
 
@@ -835,7 +887,9 @@ SubstitutionRef::adjust_mappings_for_this (
   if (resolved_mappings.empty ())
     return SubstitutionArgumentMappings::error ();
 
-  return SubstitutionArgumentMappings (resolved_mappings, mappings.get_locus (),
+  return SubstitutionArgumentMappings (resolved_mappings,
+                                      mappings.get_binding_args (),
+                                      mappings.get_locus (),
                                       mappings.get_subst_cb (),
                                       mappings.trait_item_mode ());
 }
@@ -901,6 +955,7 @@ SubstitutionRef::solve_mappings_from_receiver_for_self (
     }
 
   return SubstitutionArgumentMappings (resolved_mappings,
+                                      mappings.get_binding_args (),
                                       mappings.get_locus ());
 }
 
@@ -952,7 +1007,7 @@ SubstitutionRef::solve_missing_mappings_from_this (SubstitutionRef &ref,
       resolved_mappings.push_back (std::move (argument));
     }
 
-  return SubstitutionArgumentMappings (resolved_mappings, locus);
+  return SubstitutionArgumentMappings (resolved_mappings, {}, locus);
 }
 
 bool
index 0503528c423498ac6239075041a01e348b001cea..4f333a8ed2f73bd1075d5f120b4ef65629703b6d 100644 (file)
@@ -675,16 +675,17 @@ class SubstitutionArgumentMappings
 {
 public:
   SubstitutionArgumentMappings (std::vector<SubstitutionArg> mappings,
+                               std::map<std::string, BaseType *> binding_args,
                                Location locus,
                                ParamSubstCb param_subst_cb = nullptr,
                                bool trait_item_flag = false)
-    : mappings (mappings), locus (locus), param_subst_cb (param_subst_cb),
-      trait_item_flag (trait_item_flag)
+    : mappings (mappings), binding_args (binding_args), locus (locus),
+      param_subst_cb (param_subst_cb), trait_item_flag (trait_item_flag)
   {}
 
   SubstitutionArgumentMappings (const SubstitutionArgumentMappings &other)
-    : mappings (other.mappings), locus (other.locus),
-      param_subst_cb (other.param_subst_cb),
+    : mappings (other.mappings), binding_args (other.binding_args),
+      locus (other.locus), param_subst_cb (other.param_subst_cb),
       trait_item_flag (other.trait_item_flag)
   {}
 
@@ -692,6 +693,7 @@ public:
   operator= (const SubstitutionArgumentMappings &other)
   {
     mappings = other.mappings;
+    binding_args = other.binding_args;
     locus = other.locus;
     param_subst_cb = other.param_subst_cb;
     trait_item_flag = other.trait_item_flag;
@@ -705,7 +707,7 @@ public:
 
   static SubstitutionArgumentMappings error ()
   {
-    return SubstitutionArgumentMappings ({}, Location (), nullptr, false);
+    return SubstitutionArgumentMappings ({}, {}, Location (), nullptr, false);
   }
 
   bool is_error () const { return mappings.size () == 0; }
@@ -759,6 +761,16 @@ public:
 
   const std::vector<SubstitutionArg> &get_mappings () const { return mappings; }
 
+  std::map<std::string, BaseType *> &get_binding_args ()
+  {
+    return binding_args;
+  }
+
+  const std::map<std::string, BaseType *> &get_binding_args () const
+  {
+    return binding_args;
+  }
+
   std::string as_string () const
   {
     std::string buffer;
@@ -783,6 +795,7 @@ public:
 
 private:
   std::vector<SubstitutionArg> mappings;
+  std::map<std::string, BaseType *> binding_args;
   Location locus;
   ParamSubstCb param_subst_cb;
   bool trait_item_flag;
@@ -813,6 +826,24 @@ public:
     return buffer.empty () ? "" : "<" + buffer + ">";
   }
 
+  bool supports_associated_bindings () const
+  {
+    return get_num_associated_bindings () > 0;
+  }
+
+  // this is overridden in TypeBoundPredicate
+  // which support bindings we don't add them directly to the SubstitutionRef
+  // base class because this class represents the fn<X: Foo, Y: Bar>. The only
+  // construct which supports associated types
+  virtual size_t get_num_associated_bindings () const { return 0; }
+
+  // this is overridden in TypeBoundPredicate
+  virtual TypeBoundPredicateItem
+  lookup_associated_type (const std::string &search)
+  {
+    return TypeBoundPredicateItem::error ();
+  }
+
   size_t get_num_substitutions () const { return substitutions.size (); }
 
   std::vector<SubstitutionParamMapping> &get_substs () { return substitutions; }
@@ -1040,6 +1071,13 @@ public:
 
   DefId get_id () const { return reference; }
 
+  std::vector<TypeBoundPredicateItem> get_associated_type_items ();
+
+  size_t get_num_associated_bindings () const override final;
+
+  TypeBoundPredicateItem
+  lookup_associated_type (const std::string &search) override final;
+
 private:
   DefId reference;
   Location locus;
index ecb10d81f65f37148dceef0fe48b62458f02770c..57ff17ffcdc8dc606ce568761cf6f0dd5dfa7afe 100644 (file)
@@ -1,10 +1,12 @@
+// { dg-options "-fsyntax-only" }
 trait Foo {
     type Bar;
 }
 
 trait Copy {}
 
-
-fn c<F: Foo<Bar: Foo>>() where F::Bar: Copy { // { dg-warning "function is never used: 'c'" }
+fn c<F: Foo<Bar: Foo>>()
+where
+    F::Bar: Copy,
+{
 }
-
diff --git a/gcc/testsuite/rust/execute/torture/issue-1720.rs b/gcc/testsuite/rust/execute/torture/issue-1720.rs
new file mode 100644 (file)
index 0000000..771d7ee
--- /dev/null
@@ -0,0 +1,26 @@
+mod core {
+    mod ops {
+        #[lang = "add"]
+        pub trait Add<Rhs = Self> {
+            type Output;
+
+            fn add(self, rhs: Rhs) -> Self::Output;
+        }
+    }
+}
+
+impl core::ops::Add for i32 {
+    type Output = i32;
+
+    fn add(self, rhs: i32) -> Self::Output {
+        self + rhs
+    }
+}
+
+pub fn foo<T: core::ops::Add<Output = i32>>(a: T) -> i32 {
+    a + a
+}
+
+pub fn main() -> i32 {
+    foo(1) - 2
+}