]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: add new apply_primtiive_type_hint to inference variables
authorPhilip Herron <herron.philip@googlemail.com>
Fri, 17 Mar 2023 18:09:53 +0000 (18:09 +0000)
committerArthur Cohen <arthur.cohen@embecosm.com>
Tue, 16 Jan 2024 17:21:12 +0000 (18:21 +0100)
In the senario where you infer types via unify_site_and but choose to not
commit the result you can end up with coercion to infer the result later on
which does not get fully replaced resulting in a stray inference variable
that can be left alone as a general inference variable leading to missing
type context info. This patch gives support to add type hints to inference
variables so they can be defaulted correctly in more complex cases. The
old system relied on the unification result commiting and overriding the
inference variables so they dissapear out of the current typeing context.

This was needed to fix #1981 where it is valid to inject inference
variables here. This lead to a regression in a few of the complex generic
trait test cases such as execute/torture/traits9.rs which had the wrong
argument type and defaulted wrongly to i32 instead of isize.

gcc/rust/ChangeLog:

* typecheck/rust-hir-type-check-base.cc (TypeCheckBase::resolve_literal): fix ctor
* typecheck/rust-hir-type-check-stmt.cc (TypeCheckStmt::visit): likewise
* typecheck/rust-hir-type-check-type.cc (TypeCheckType::visit): likewise
* typecheck/rust-typecheck-context.cc (TypeCheckContext::push_new_loop_context): likewise
* typecheck/rust-tyty-util.cc (TyVar::get_implicit_infer_var): likewise
* typecheck/rust-tyty.cc (InferType::InferType): new ctor with type hint
(InferType::clone): fix ctor
(InferType::apply_primitive_type_hint): new function to apply possible hint
* typecheck/rust-tyty.h: update prototypes
* typecheck/rust-unify.cc (UnifyRules::expect_inference_variable): apply type hints
(UnifyRules::expect_bool): likewise
(UnifyRules::expect_char): likewise
(UnifyRules::expect_int): likewise
(UnifyRules::expect_uint): likewise
(UnifyRules::expect_float): likewise
(UnifyRules::expect_isize): likewise
(UnifyRules::expect_usize): likewise

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/typecheck/rust-hir-type-check-base.cc
gcc/rust/typecheck/rust-hir-type-check-stmt.cc
gcc/rust/typecheck/rust-hir-type-check-type.cc
gcc/rust/typecheck/rust-typecheck-context.cc
gcc/rust/typecheck/rust-tyty-util.cc
gcc/rust/typecheck/rust-tyty.cc
gcc/rust/typecheck/rust-tyty.h
gcc/rust/typecheck/rust-unify.cc

index 4d53526fdc93b2d0dfa3bddaf5aa344f58cf5b58..6ef64e618a47aa9a4738d138d04c8a6e285d2562 100644 (file)
@@ -165,6 +165,7 @@ TypeCheckBase::resolve_literal (const Analysis::NodeMapping &expr_mappings,
            infered
              = new TyTy::InferType (expr_mappings.get_hirid (),
                                     TyTy::InferType::InferTypeKind::INTEGRAL,
+                                    TyTy::InferType::TypeHint::Default (),
                                     locus);
            break;
          }
@@ -189,6 +190,7 @@ TypeCheckBase::resolve_literal (const Analysis::NodeMapping &expr_mappings,
            infered
              = new TyTy::InferType (expr_mappings.get_hirid (),
                                     TyTy::InferType::InferTypeKind::FLOAT,
+                                    TyTy::InferType::TypeHint::Default (),
                                     locus);
            break;
          }
index d1fb45a8dbe32a6df2adf9918e287243c341ca37..17068a95bf139a3a8fe1da302d84be79462f5ac7 100644 (file)
@@ -127,11 +127,11 @@ TypeCheckStmt::visit (HIR::LetStmt &stmt)
       // let x;
       else
        {
-         TypeCheckPattern::Resolve (
-           &stmt_pattern,
-           new TyTy::InferType (
-             stmt_pattern.get_pattern_mappings ().get_hirid (),
-             TyTy::InferType::InferTypeKind::GENERAL, stmt.get_locus ()));
+         auto infer = new TyTy::InferType (
+           stmt_pattern.get_pattern_mappings ().get_hirid (),
+           TyTy::InferType::InferTypeKind::GENERAL,
+           TyTy::InferType::TypeHint::Default (), stmt.get_locus ());
+         TypeCheckPattern::Resolve (&stmt_pattern, infer);
        }
     }
 }
index 50a39fdf308e25d547cf6cba87e50408a258e409..e5564ab4ee8b09a83800b8c218c19f89189cfd2b 100644 (file)
@@ -635,6 +635,7 @@ TypeCheckType::visit (HIR::InferredType &type)
 {
   translated = new TyTy::InferType (type.get_mappings ().get_hirid (),
                                    TyTy::InferType::InferTypeKind::GENERAL,
+                                   TyTy::InferType::TypeHint::Default (),
                                    type.get_locus ());
 }
 
index 093bc0a702ecccecd8cef2c1a6147835f93a882f..7b2c96cdce289f55f27da5832d376a9d7a7a399a 100644 (file)
@@ -184,7 +184,8 @@ void
 TypeCheckContext::push_new_loop_context (HirId id, Location locus)
 {
   TyTy::BaseType *infer_var
-    = new TyTy::InferType (id, TyTy::InferType::InferTypeKind::GENERAL, locus);
+    = new TyTy::InferType (id, TyTy::InferType::InferTypeKind::GENERAL,
+                          TyTy::InferType::TypeHint::Default (), locus);
   loop_type_stack.push_back (infer_var);
 }
 
index a333cf41351adef802b0ba502edda1c9fe0180cf..7002efd91987e139e2d4564d65d5b0ca9642d4c3 100644 (file)
@@ -48,7 +48,8 @@ TyVar::get_implicit_infer_var (Location locus)
   auto context = Resolver::TypeCheckContext::get ();
 
   InferType *infer = new InferType (mappings->get_next_hir_id (),
-                                   InferType::InferTypeKind::GENERAL, locus);
+                                   InferType::InferTypeKind::GENERAL,
+                                   InferType::TypeHint::Default (), locus);
   context->insert_type (Analysis::NodeMapping (mappings->get_current_crate (),
                                               UNKNOWN_NODEID,
                                               infer->get_ref (),
index 8c542c5eb80e5b2c7ad7835931ed2aeded55e9a7..1c673181e821b2bd5db6300ca3123e90cffde796 100644 (file)
@@ -935,18 +935,18 @@ BaseType::needs_generic_substitutions () const
 
 // InferType
 
-InferType::InferType (HirId ref, InferTypeKind infer_kind, Location locus,
-                     std::set<HirId> refs)
+InferType::InferType (HirId ref, InferTypeKind infer_kind, TypeHint hint,
+                     Location locus, std::set<HirId> refs)
   : BaseType (ref, ref, TypeKind::INFER,
              {Resolver::CanonicalPath::create_empty (), locus}, refs),
-    infer_kind (infer_kind)
+    infer_kind (infer_kind), default_hint (hint)
 {}
 
 InferType::InferType (HirId ref, HirId ty_ref, InferTypeKind infer_kind,
-                     Location locus, std::set<HirId> refs)
+                     TypeHint hint, Location locus, std::set<HirId> refs)
   : BaseType (ref, ty_ref, TypeKind::INFER,
              {Resolver::CanonicalPath::create_empty (), locus}, refs),
-    infer_kind (infer_kind)
+    infer_kind (infer_kind), default_hint (hint)
 {}
 
 InferType::InferTypeKind
@@ -1012,7 +1012,7 @@ InferType::clone () const
 
   InferType *clone
     = new InferType (mappings->get_next_hir_id (), get_infer_kind (),
-                    get_ident ().locus, get_combined_refs ());
+                    default_hint, get_ident ().locus, get_combined_refs ());
 
   context->insert_type (Analysis::NodeMapping (mappings->get_current_crate (),
                                               UNKNOWN_NODEID,
@@ -1033,24 +1033,217 @@ InferType::default_type (BaseType **type) const
 {
   auto context = Resolver::TypeCheckContext::get ();
   bool ok = false;
-  switch (infer_kind)
+
+  if (default_hint.kind == TypeKind::ERROR)
     {
-    case GENERAL:
+      switch (infer_kind)
+       {
+       case GENERAL:
+         return false;
+
+         case INTEGRAL: {
+           ok = context->lookup_builtin ("i32", type);
+           rust_assert (ok);
+           return ok;
+         }
+
+         case FLOAT: {
+           ok = context->lookup_builtin ("f64", type);
+           rust_assert (ok);
+           return ok;
+         }
+       }
+      return false;
+    }
+
+  switch (default_hint.kind)
+    {
+    case ISIZE:
+      ok = context->lookup_builtin ("isize", type);
+      rust_assert (ok);
+      return ok;
+
+    case USIZE:
+      ok = context->lookup_builtin ("usize", type);
+      rust_assert (ok);
+      return ok;
+
+    case INT:
+      switch (default_hint.szhint)
+       {
+       case TypeHint::SizeHint::S8:
+         ok = context->lookup_builtin ("i8", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S16:
+         ok = context->lookup_builtin ("i16", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S32:
+         ok = context->lookup_builtin ("i32", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S64:
+         ok = context->lookup_builtin ("i64", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S128:
+         ok = context->lookup_builtin ("i128", type);
+         rust_assert (ok);
+         return ok;
+
+       default:
+         return false;
+       }
+      break;
+
+    case UINT:
+      switch (default_hint.szhint)
+       {
+       case TypeHint::SizeHint::S8:
+         ok = context->lookup_builtin ("u8", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S16:
+         ok = context->lookup_builtin ("u16", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S32:
+         ok = context->lookup_builtin ("u32", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S64:
+         ok = context->lookup_builtin ("u64", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S128:
+         ok = context->lookup_builtin ("u128", type);
+         rust_assert (ok);
+         return ok;
+
+       default:
+         return false;
+       }
+      break;
+
+    case TypeKind::FLOAT:
+      switch (default_hint.szhint)
+       {
+       case TypeHint::SizeHint::S32:
+         ok = context->lookup_builtin ("f32", type);
+         rust_assert (ok);
+         return ok;
+
+       case TypeHint::SizeHint::S64:
+         ok = context->lookup_builtin ("f64", type);
+         rust_assert (ok);
+         return ok;
+
+       default:
+         return false;
+       }
+      break;
+
+    default:
       return false;
+    }
 
-      case INTEGRAL: {
-       ok = context->lookup_builtin ("i32", type);
-       rust_assert (ok);
-       return ok;
+  return false;
+}
+
+void
+InferType::apply_primitive_type_hint (const BaseType &hint)
+{
+  switch (hint.get_kind ())
+    {
+    case ISIZE:
+    case USIZE:
+      infer_kind = INTEGRAL;
+      default_hint.kind = hint.get_kind ();
+      break;
+
+      case INT: {
+       infer_kind = INTEGRAL;
+       const IntType &i = static_cast<const IntType &> (hint);
+       default_hint.kind = hint.get_kind ();
+       default_hint.shint = TypeHint::SignedHint::SIGNED;
+       switch (i.get_int_kind ())
+         {
+         case IntType::I8:
+           default_hint.szhint = TypeHint::SizeHint::S8;
+           break;
+         case IntType::I16:
+           default_hint.szhint = TypeHint::SizeHint::S16;
+           break;
+         case IntType::I32:
+           default_hint.szhint = TypeHint::SizeHint::S32;
+           break;
+         case IntType::I64:
+           default_hint.szhint = TypeHint::SizeHint::S64;
+           break;
+         case IntType::I128:
+           default_hint.szhint = TypeHint::SizeHint::S128;
+           break;
+         }
+      }
+      break;
+
+      case UINT: {
+       infer_kind = INTEGRAL;
+       const UintType &i = static_cast<const UintType &> (hint);
+       default_hint.kind = hint.get_kind ();
+       default_hint.shint = TypeHint::SignedHint::UNSIGNED;
+       switch (i.get_uint_kind ())
+         {
+         case UintType::U8:
+           default_hint.szhint = TypeHint::SizeHint::S8;
+           break;
+         case UintType::U16:
+           default_hint.szhint = TypeHint::SizeHint::S16;
+           break;
+         case UintType::U32:
+           default_hint.szhint = TypeHint::SizeHint::S32;
+           break;
+         case UintType::U64:
+           default_hint.szhint = TypeHint::SizeHint::S64;
+           break;
+         case UintType::U128:
+           default_hint.szhint = TypeHint::SizeHint::S128;
+           break;
+         }
       }
+      break;
+
+      case TypeKind::FLOAT: {
+       infer_kind = FLOAT;
+       default_hint.shint = TypeHint::SignedHint::SIGNED;
+       default_hint.kind = hint.get_kind ();
+       const FloatType &i = static_cast<const FloatType &> (hint);
+       switch (i.get_float_kind ())
+         {
+         case FloatType::F32:
+           default_hint.szhint = TypeHint::SizeHint::S32;
+           break;
 
-      case FLOAT: {
-       ok = context->lookup_builtin ("f64", type);
-       rust_assert (ok);
-       return ok;
+         case FloatType::F64:
+           default_hint.szhint = TypeHint::SizeHint::S64;
+           break;
+         }
       }
+      break;
+
+    default:
+      // TODO bool, char, never??
+      break;
     }
-  return false;
 }
 
 // ErrorType
index 8fd680d8d81025a50bd9545b3e49e534c7cbb80e..6bbaae7f4415acb9f09635af534c78a82471c550 100644 (file)
@@ -188,12 +188,41 @@ public:
     FLOAT
   };
 
-  InferType (HirId ref, InferTypeKind infer_kind, Location locus,
-            std::set<HirId> refs = std::set<HirId> ());
+  struct TypeHint
+  {
+    enum SignedHint
+    {
+      SIGNED,
+      UNSIGNED,
+
+      UNKNOWN
+    };
+    enum SizeHint
+    {
+      S8,
+      S16,
+      S32,
+      S64,
+      S128,
+      SUNKNOWN
+    };
+
+    TyTy::TypeKind kind;
+    SignedHint shint;
+    SizeHint szhint;
+
+    static TypeHint Default ()
+    {
+      return TypeHint{TypeKind::ERROR, UNKNOWN, SUNKNOWN};
+    }
+  };
 
-  InferType (HirId ref, HirId ty_ref, InferTypeKind infer_kind, Location locus,
+  InferType (HirId ref, InferTypeKind infer_kind, TypeHint hint, Location locus,
             std::set<HirId> refs = std::set<HirId> ());
 
+  InferType (HirId ref, HirId ty_ref, InferTypeKind infer_kind, TypeHint hint,
+            Location locus, std::set<HirId> refs = std::set<HirId> ());
+
   void accept_vis (TyVisitor &vis) override;
   void accept_vis (TyConstVisitor &vis) const override;
 
@@ -209,8 +238,11 @@ public:
 
   bool default_type (BaseType **type) const;
 
+  void apply_primitive_type_hint (const TyTy::BaseType &hint);
+
 private:
   InferTypeKind infer_kind;
+  TypeHint default_hint;
 };
 
 class ErrorType : public BaseType
index 3af261db59e90c90e7bc9fa6797e9e6324bd8bde..6e39e98dfb1df83f5feda948bcb131d77619bf7e 100644 (file)
@@ -334,7 +334,10 @@ UnifyRules::expect_inference_variable (TyTy::InferType *ltype,
                        || (ltype->get_infer_kind ()
                            == TyTy::InferType::InferTypeKind::INTEGRAL);
        if (is_valid)
-         return rtype->clone ();
+         {
+           ltype->apply_primitive_type_hint (*rtype);
+           return rtype->clone ();
+         }
       }
       break;
 
@@ -344,7 +347,10 @@ UnifyRules::expect_inference_variable (TyTy::InferType *ltype,
                        || (ltype->get_infer_kind ()
                            == TyTy::InferType::InferTypeKind::FLOAT);
        if (is_valid)
-         return rtype->clone ();
+         {
+           ltype->apply_primitive_type_hint (*rtype);
+           return rtype->clone ();
+         }
       }
       break;
 
@@ -1133,7 +1139,10 @@ UnifyRules::expect_bool (TyTy::BoolType *ltype, TyTy::BaseType *rtype)
        bool is_valid
          = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1177,7 +1186,10 @@ UnifyRules::expect_char (TyTy::CharType *ltype, TyTy::BaseType *rtype)
        bool is_valid
          = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1222,7 +1234,10 @@ UnifyRules::expect_int (TyTy::IntType *ltype, TyTy::BaseType *rtype)
          = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL
            || r->get_infer_kind () == TyTy::InferType::InferTypeKind::INTEGRAL;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1273,7 +1288,10 @@ UnifyRules::expect_uint (TyTy::UintType *ltype, TyTy::BaseType *rtype)
          = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL
            || r->get_infer_kind () == TyTy::InferType::InferTypeKind::INTEGRAL;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1324,7 +1342,10 @@ UnifyRules::expect_float (TyTy::FloatType *ltype, TyTy::BaseType *rtype)
          = r->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL
            || r->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1374,7 +1395,10 @@ UnifyRules::expect_isize (TyTy::ISizeType *ltype, TyTy::BaseType *rtype)
        bool is_valid
          = r->get_infer_kind () != TyTy::InferType::InferTypeKind::FLOAT;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;
 
@@ -1418,7 +1442,10 @@ UnifyRules::expect_usize (TyTy::USizeType *ltype, TyTy::BaseType *rtype)
        bool is_valid
          = r->get_infer_kind () != TyTy::InferType::InferTypeKind::FLOAT;
        if (is_valid)
-         return ltype->clone ();
+         {
+           r->apply_primitive_type_hint (*ltype);
+           return ltype->clone ();
+         }
       }
       break;