]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: Initial support for Return Position Impl Trait
authorPhilip Herron <herron.philip@googlemail.com>
Mon, 19 May 2025 17:02:13 +0000 (18:02 +0100)
committerPhilip Herron <philip.herron@embecosm.com>
Mon, 26 May 2025 19:18:47 +0000 (19:18 +0000)
This is the initial patch for RPIT, we can build on this to handle the
more complex cases but there are enough distinct changes going on here
that it should just get merged now.

RPIT is really a sneaky generic so for example:

  fn foo() -> impl Bar {
      Baz
  }

This is represented as: fn () -> OpaqueType Bar. But when we handle the
coercion site for Baz on impl Bar when we type resolve the function we
know that the underlying type  is Baz. Note this function is _not_ generic
so its using this special OpaqueType and keeping track of the underlying type
in its ty_ref reference hir-id which will resolve to Baz.

This also means if we have a case where maybe this was in an if statement:

  fn foo(a: i32) -> impl Bar {
      if a > 10 {
        Baz
      } else {
        Qux
      }
  }

The rules of impl Bar is that Baz is handled but Baz and Qux are different
underlying types so this is not allowed. The reason is impl traits are not
generic and although from a programmer perspective the callers dont know what
the underlying type is, the compiler _knows_ what it is. So really when
you call a function and get its return position impl trait the compiler knows
what to do and does all whats nessecary to handle calling functions using that
type etc.

gcc/rust/ChangeLog:

* backend/rust-compile-type.cc (TyTyResolveCompile::visit): we need to resolve the
underlying type
* typecheck/rust-substitution-mapper.cc (SubstMapperInternal::visit): just clone
* typecheck/rust-tyty-call.cc (TypeCheckCallExpr::visit):
ensure we monomphize to get the underlying
* typecheck/rust-tyty.cc (BaseType::destructure): handle opaque types
(OpaqueType::resolve): this is much simpler now
(OpaqueType::handle_substitions): no longer needed
* typecheck/rust-tyty.h: update header
* typecheck/rust-unify.cc (UnifyRules::expect_opaque): unify rules for opaque

gcc/testsuite/ChangeLog:

* rust/compile/bad-rpit1.rs: New test.
* rust/execute/torture/impl_rpit1.rs: New test.
* rust/execute/torture/impl_rpit2.rs: New test.
* rust/execute/torture/impl_rpit3.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/backend/rust-compile-type.cc
gcc/rust/typecheck/rust-substitution-mapper.cc
gcc/rust/typecheck/rust-tyty-call.cc
gcc/rust/typecheck/rust-tyty.cc
gcc/rust/typecheck/rust-tyty.h
gcc/rust/typecheck/rust-unify.cc
gcc/testsuite/rust/compile/bad-rpit1.rs [new file with mode: 0644]
gcc/testsuite/rust/execute/torture/impl_rpit1.rs [new file with mode: 0644]
gcc/testsuite/rust/execute/torture/impl_rpit2.rs [new file with mode: 0644]
gcc/testsuite/rust/execute/torture/impl_rpit3.rs [new file with mode: 0644]

index 903d0ce85416ab9683f5c9b9b7c83eb289bc836b..5ca1d82b3d40ee767bc1d7dc3ae5ca85e5402aba 100644 (file)
@@ -755,7 +755,9 @@ TyTyResolveCompile::visit (const TyTy::DynamicObjectType &type)
 void
 TyTyResolveCompile::visit (const TyTy::OpaqueType &type)
 {
-  translated = error_mark_node;
+  rust_assert (type.can_resolve ());
+  auto underlying = type.resolve ();
+  translated = TyTyResolveCompile::compile (ctx, underlying, trait_object_mode);
 }
 
 tree
index 878c4d54a0a802e89ee38deb6a9a1a3e2e56156f..23116ff804018ba668cd51881e56f82d9c53a8b0 100644 (file)
@@ -374,7 +374,7 @@ SubstMapperInternal::visit (TyTy::DynamicObjectType &type)
 void
 SubstMapperInternal::visit (TyTy::OpaqueType &type)
 {
-  resolved = type.handle_substitions (mappings);
+  resolved = type.clone ();
 }
 
 // SubstMapperFromExisting
index 75cf58faf89e5e1647e5ae4f01b5d0bec1cd5b34..0d187f37ee380885278e26bab00252a8c879a186 100644 (file)
@@ -246,7 +246,7 @@ TypeCheckCallExpr::visit (FnType &type)
     }
 
   type.monomorphize ();
-  resolved = type.get_return_type ()->clone ();
+  resolved = type.get_return_type ()->monomorphized_clone ();
 }
 
 void
index 09a9b97354be6e668bc00e2a6dd0da2be5acd7da..c35a6c1294891fa2f234dcc2b63991725b7c65a1 100644 (file)
@@ -546,17 +546,14 @@ BaseType::destructure () const
        {
          x = p->get ();
        }
-      // else if (auto p = x->try_as<const OpaqueType> ())
-      //   {
-      //     auto pr = p->resolve ();
-
-      //     rust_debug ("XXXXXX")
-
-      //     if (pr == x)
-      //       return pr;
+      else if (auto p = x->try_as<const OpaqueType> ())
+       {
+         auto pr = p->resolve ();
+         if (pr == x)
+           return pr;
 
-      //     x = pr;
-      //   }
+         x = pr;
+       }
       else
        {
          return x;
@@ -3624,28 +3621,7 @@ BaseType *
 OpaqueType::resolve () const
 {
   TyVar var (get_ty_ref ());
-  BaseType *r = var.get_tyty ();
-
-  while (r->get_kind () == TypeKind::OPAQUE)
-    {
-      OpaqueType *rr = static_cast<OpaqueType *> (r);
-      if (!rr->can_resolve ())
-       break;
-
-      TyVar v (rr->get_ty_ref ());
-      BaseType *n = v.get_tyty ();
-
-      // fix infinite loop
-      if (r == n)
-       break;
-
-      r = n;
-    }
-
-  if (r->get_kind () == TypeKind::OPAQUE && (r->get_ref () == r->get_ty_ref ()))
-    return TyVar (r->get_ty_ref ()).get_tyty ();
-
-  return r;
+  return var.get_tyty ();
 }
 
 bool
@@ -3655,41 +3631,9 @@ OpaqueType::is_equal (const BaseType &other) const
   if (can_resolve () != other2.can_resolve ())
     return false;
 
-  if (can_resolve ())
-    return resolve ()->can_eq (other2.resolve (), false);
-
   return bounds_compatible (other, UNDEF_LOCATION, false);
 }
 
-OpaqueType *
-OpaqueType::handle_substitions (SubstitutionArgumentMappings &subst_mappings)
-{
-  // SubstitutionArg arg = SubstitutionArg::error ();
-  // bool ok = subst_mappings.get_argument_for_symbol (this, &arg);
-  // if (!ok || arg.is_error ())
-  //   return this;
-
-  // OpaqueType *p = static_cast<OpaqueType *> (clone ());
-  // subst_mappings.on_param_subst (*p, arg);
-
-  // // there are two cases one where we substitute directly to a new PARAM and
-  // // otherwise
-  // if (arg.get_tyty ()->get_kind () == TyTy::TypeKind::PARAM)
-  //   {
-  //     p->set_ty_ref (arg.get_tyty ()->get_ref ());
-  //     return p;
-  //   }
-
-  // // this is the new subst that this needs to pass
-  // p->set_ref (mappings.get_next_hir_id ());
-  // p->set_ty_ref (arg.get_tyty ()->get_ref ());
-
-  // return p;
-
-  rust_unreachable ();
-  return nullptr;
-}
-
 // StrType
 
 StrType::StrType (HirId ref, std::set<HirId> refs)
index b14e9f25221de58a1bfe2d9bdea84c0767ea5f58..e35585c6a1301d4c4fdd5b55c572bf1c65442b6e 100644 (file)
@@ -441,8 +441,6 @@ public:
   std::string get_name () const override final;
 
   bool is_equal (const BaseType &other) const override;
-
-  OpaqueType *handle_substitions (SubstitutionArgumentMappings &mappings);
 };
 
 class StructFieldType
index 4cfc197c430ad69d91c72ce53a06b90f4029134c..219db9a6df9e861e6f3f7dcc474867b7f49df458 100644 (file)
@@ -1788,59 +1788,51 @@ UnifyRules::expect_closure (TyTy::ClosureType *ltype, TyTy::BaseType *rtype)
 TyTy::BaseType *
 UnifyRules::expect_opaque (TyTy::OpaqueType *ltype, TyTy::BaseType *rtype)
 {
-  switch (rtype->get_kind ())
+  if (rtype->is<TyTy::OpaqueType> ())
     {
-      case TyTy::INFER: {
-       TyTy::InferType *r = static_cast<TyTy::InferType *> (rtype);
-       bool is_valid
-         = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL;
-       if (is_valid)
-         return ltype->clone ();
-      }
-      break;
-
-      case TyTy::OPAQUE: {
-       auto &type = *static_cast<TyTy::OpaqueType *> (rtype);
-       if (ltype->num_specified_bounds () != type.num_specified_bounds ())
-         {
-           return new TyTy::ErrorType (0);
-         }
+      TyTy::OpaqueType *ro = rtype->as<TyTy::OpaqueType> ();
+      if (!ltype->is_equal (*ro))
+       return new TyTy::ErrorType (0);
 
-       if (!ltype->bounds_compatible (type, locus, true))
-         {
+      if (ltype->can_resolve () && ro->can_resolve ())
+       {
+         auto lr = ltype->resolve ();
+         auto rr = ro->resolve ();
+
+         auto res = UnifyRules::Resolve (TyTy::TyWithLocation (lr),
+                                         TyTy::TyWithLocation (rr), locus,
+                                         commit_flag, false /* emit_error */,
+                                         infer_flag, commits, infers);
+         if (res->get_kind () == TyTy::TypeKind::ERROR)
            return new TyTy::ErrorType (0);
-         }
-
-       return ltype->clone ();
-      }
-      break;
-
-    case TyTy::CLOSURE:
-    case TyTy::SLICE:
-    case TyTy::PARAM:
-    case TyTy::POINTER:
-    case TyTy::STR:
-    case TyTy::ADT:
-    case TyTy::REF:
-    case TyTy::ARRAY:
-    case TyTy::FNDEF:
-    case TyTy::FNPTR:
-    case TyTy::TUPLE:
-    case TyTy::BOOL:
-    case TyTy::CHAR:
-    case TyTy::INT:
-    case TyTy::UINT:
-    case TyTy::FLOAT:
-    case TyTy::USIZE:
-    case TyTy::ISIZE:
-    case TyTy::NEVER:
-    case TyTy::PLACEHOLDER:
-    case TyTy::PROJECTION:
-    case TyTy::DYNAMIC:
-    case TyTy::ERROR:
-      return new TyTy::ErrorType (0);
+       }
+      else if (ltype->can_resolve ())
+       {
+         auto lr = ltype->resolve ();
+         ro->set_ty_ref (lr->get_ref ());
+       }
+      else if (ro->can_resolve ())
+       {
+         auto rr = ro->resolve ();
+         ltype->set_ty_ref (rr->get_ref ());
+       }
     }
-  return new TyTy::ErrorType (0);
+  else if (ltype->can_resolve ())
+    {
+      auto underly = ltype->resolve ();
+      auto res = UnifyRules::Resolve (TyTy::TyWithLocation (underly),
+                                     TyTy::TyWithLocation (rtype), locus,
+                                     commit_flag, false /* emit_error */,
+                                     infer_flag, commits, infers);
+      if (res->get_kind () == TyTy::TypeKind::ERROR)
+       return new TyTy::ErrorType (0);
+    }
+  else
+    {
+      ltype->set_ty_ref (rtype->get_ref ());
+    }
+
+  return ltype;
 }
 
 } // namespace Resolver
diff --git a/gcc/testsuite/rust/compile/bad-rpit1.rs b/gcc/testsuite/rust/compile/bad-rpit1.rs
new file mode 100644 (file)
index 0000000..d8c21b1
--- /dev/null
@@ -0,0 +1,26 @@
+#[lang = "sized"]
+trait Sized {}
+
+trait Foo {
+    fn id(&self) -> i32;
+}
+
+struct A;
+struct B;
+
+impl Foo for A {
+    fn id(&self) -> i32 {
+        1
+    }
+}
+
+impl Foo for B {
+    fn id(&self) -> i32 {
+        2
+    }
+}
+
+fn make_foo(cond: bool) -> impl Foo {
+    if cond { A } else { B }
+    // { dg-error "mismatched types, expected .A. but got .B. .E0308." "" { target *-*-* } .-1 }
+}
diff --git a/gcc/testsuite/rust/execute/torture/impl_rpit1.rs b/gcc/testsuite/rust/execute/torture/impl_rpit1.rs
new file mode 100644 (file)
index 0000000..8ce5f21
--- /dev/null
@@ -0,0 +1,28 @@
+#[lang = "sized"]
+trait Sized {}
+
+trait Foo {
+    fn id(&self) -> i32;
+}
+
+struct Thing(i32);
+
+impl Foo for Thing {
+    fn id(&self) -> i32 {
+        self.0
+    }
+}
+
+fn make_thing(a: i32) -> impl Foo {
+    Thing(a)
+}
+
+fn use_foo(f: impl Foo) -> i32 {
+    f.id()
+}
+
+fn main() -> i32 {
+    let value = make_thing(42);
+    let val = use_foo(value);
+    val - 42
+}
diff --git a/gcc/testsuite/rust/execute/torture/impl_rpit2.rs b/gcc/testsuite/rust/execute/torture/impl_rpit2.rs
new file mode 100644 (file)
index 0000000..f7cbbb6
--- /dev/null
@@ -0,0 +1,36 @@
+#[lang = "sized"]
+trait Sized {}
+
+trait Foo {
+    fn id(&self) -> i32;
+}
+
+struct Thing(i32);
+
+impl Thing {
+    fn double(&self) -> i32 {
+        // { dg-warning "associated function is never used: .double." "" { target *-*-* } .-1 }
+        self.0 * 2
+    }
+}
+
+impl Foo for Thing {
+    fn id(&self) -> i32 {
+        self.0
+    }
+}
+
+fn make_thing(a: i32) -> impl Foo {
+    Thing(a)
+}
+
+fn use_foo(f: impl Foo) -> i32 {
+    f.id()
+}
+
+fn main() -> i32 {
+    let value = make_thing(21);
+    let id = use_foo(value);
+
+    id - 21
+}
diff --git a/gcc/testsuite/rust/execute/torture/impl_rpit3.rs b/gcc/testsuite/rust/execute/torture/impl_rpit3.rs
new file mode 100644 (file)
index 0000000..dd68eb2
--- /dev/null
@@ -0,0 +1,25 @@
+#[lang = "sized"]
+trait Sized {}
+
+trait Foo {
+    fn id(&self) -> i32;
+}
+
+struct Thing(i32);
+
+impl Foo for Thing {
+    fn id(&self) -> i32 {
+        self.0
+    }
+}
+
+fn make_thing() -> impl Foo {
+    Thing(99)
+}
+
+fn main() -> i32 {
+    let v = make_thing();
+    let r = &v;
+    let val = r.id();
+    val - 99
+}