]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: derive(PartialEq): Implement proper discriminant comparison
authorArthur Cohen <arthur.cohen@embecosm.com>
Tue, 22 Apr 2025 20:21:01 +0000 (22:21 +0200)
committerArthur Cohen <arthur.cohen@embecosm.com>
Tue, 5 Aug 2025 14:36:48 +0000 (16:36 +0200)
And use the new EnumMatchBuilder class to do so.

gcc/rust/ChangeLog:

* expand/rust-derive-partial-eq.cc (DerivePartialEq::eq_fn): Change signature.
(DerivePartialEq::visit_tuple): Use new eq_fn API.
(DerivePartialEq::visit_struct): Likewise.
(DerivePartialEq::visit_enum): Implement proper discriminant comparison.
* expand/rust-derive-partial-eq.h: Change eq_fn signature.

gcc/testsuite/ChangeLog:

* rust/execute/torture/derive-partialeq2.rs: Add declaration for
discriminant_value.

gcc/rust/expand/rust-derive-partial-eq.cc
gcc/rust/expand/rust-derive-partial-eq.h
gcc/testsuite/rust/execute/torture/derive-partialeq2.rs

index 1f2fa35b284e7320bae506fbc815459994be10a9..ea6a995c4f100fbc88a9121fa68fe5047067bb96 100644 (file)
@@ -64,11 +64,9 @@ DerivePartialEq::partialeq_impls (
 }
 
 std::unique_ptr<AssociatedItem>
-DerivePartialEq::eq_fn (std::unique_ptr<Expr> &&cmp_expression,
+DerivePartialEq::eq_fn (std::unique_ptr<BlockExpr> &&block,
                        std::string type_name)
 {
-  auto block = builder.block (tl::nullopt, std::move (cmp_expression));
-
   auto self_type
     = std::unique_ptr<TypeNoBounds> (new TypePath (builder.type_path ("Self")));
 
@@ -118,7 +116,8 @@ DerivePartialEq::visit_tuple (TupleStruct &item)
   auto type_name = item.get_struct_name ().as_string ();
   auto fields = SelfOther::indexes (builder, item.get_fields ());
 
-  auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name);
+  auto fn = eq_fn (builder.block (build_eq_expression (std::move (fields))),
+                  type_name);
 
   expanded
     = partialeq_impls (std::move (fn), type_name, item.get_generic_params ());
@@ -130,7 +129,8 @@ DerivePartialEq::visit_struct (StructStruct &item)
   auto type_name = item.get_struct_name ().as_string ();
   auto fields = SelfOther::fields (builder, item.get_fields ());
 
-  auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name);
+  auto fn = eq_fn (builder.block (build_eq_expression (std::move (fields))),
+                  type_name);
 
   expanded
     = partialeq_impls (std::move (fn), type_name, item.get_generic_params ());
@@ -270,46 +270,58 @@ DerivePartialEq::visit_enum (Enum &item)
   auto cases = std::vector<MatchCase> ();
   auto type_name = item.get_identifier ().as_string ();
 
+  auto eq_expr_fn = [this] (std::vector<SelfOther> &&fields) {
+    return build_eq_expression (std::move (fields));
+  };
+
+  auto let_sd
+    = builder.discriminant_value (DerivePartialEq::self_discr, "self");
+  auto let_od
+    = builder.discriminant_value (DerivePartialEq::other_discr, "other");
+
+  auto discr_cmp
+    = builder.comparison_expr (builder.identifier (DerivePartialEq::self_discr),
+                              builder.identifier (
+                                DerivePartialEq::other_discr),
+                              ComparisonOperator::EQUAL);
+
   for (auto &variant : item.get_variants ())
     {
       auto variant_path
        = builder.variant_path (type_name,
                                variant->get_identifier ().as_string ());
 
+      auto enum_builder = EnumMatchBuilder (variant_path, eq_expr_fn, builder);
+
       switch (variant->get_enum_item_kind ())
        {
-       case EnumItem::Kind::Identifier:
-       case EnumItem::Kind::Discriminant:
-         cases.emplace_back (match_enum_identifier (variant_path, variant));
-         break;
        case EnumItem::Kind::Tuple:
-         cases.emplace_back (
-           match_enum_tuple (variant_path,
-                             static_cast<EnumItemTuple &> (*variant)));
+         cases.emplace_back (enum_builder.tuple (*variant));
          break;
        case EnumItem::Kind::Struct:
-         cases.emplace_back (
-           match_enum_struct (variant_path,
-                              static_cast<EnumItemStruct &> (*variant)));
+         cases.emplace_back (enum_builder.strukt (*variant));
+         break;
+       case EnumItem::Kind::Identifier:
+       case EnumItem::Kind::Discriminant:
+         // We don't need to do anything for these, as they are handled by the
+         // discriminant value comparison
          break;
        }
     }
 
-  // NOTE: Mention using discriminant_value and skipping that last case, and
-  // instead skipping all identifiers/discriminant enum items and returning
-  // `true` in the wildcard case
-
   // In case the two instances of `Self` don't have the same discriminant,
   // automatically return false.
   cases.emplace_back (
-    builder.match_case (builder.wildcard (), builder.literal_bool (false)));
+    builder.match_case (builder.wildcard (), std::move (discr_cmp)));
 
   auto match
     = builder.match (builder.tuple (vec (builder.identifier ("self"),
                                         builder.identifier ("other"))),
                     std::move (cases));
 
-  auto fn = eq_fn (std::move (match), type_name);
+  auto fn = eq_fn (builder.block (vec (std::move (let_sd), std::move (let_od)),
+                                 std::move (match)),
+                  type_name);
 
   expanded
     = partialeq_impls (std::move (fn), type_name, item.get_generic_params ());
index fdfe4dacb85a56f908820c49b0d8fb2eb71c2181..7985414c252874e1df8bc950f7b06b98112dbcb0 100644 (file)
@@ -44,7 +44,7 @@ private:
     std::unique_ptr<AssociatedItem> &&eq_fn, std::string name,
     const std::vector<std::unique_ptr<GenericParam>> &type_generics);
 
-  std::unique_ptr<AssociatedItem> eq_fn (std::unique_ptr<Expr> &&cmp_expression,
+  std::unique_ptr<AssociatedItem> eq_fn (std::unique_ptr<BlockExpr> &&block,
                                         std::string type_name);
 
   /**
@@ -61,6 +61,9 @@ private:
   MatchCase match_enum_struct (PathInExpression variant_path,
                               const EnumItemStruct &variant);
 
+  constexpr static const char *self_discr = "#self_discr";
+  constexpr static const char *other_discr = "#other_discr";
+
   virtual void visit_struct (StructStruct &item) override;
   virtual void visit_tuple (TupleStruct &item) override;
   virtual void visit_enum (Enum &item) override;
index 70ed7dcd93d9f5ec1bef80ec4e31ddd753698e24..e316017753a2ad38b93a31a913209e742ff334b0 100644 (file)
@@ -2,6 +2,20 @@
 
 #![feature(intrinsics)]
 
+pub mod core {
+    pub mod intrinsics {
+        #[lang = "discriminant_kind"]
+        pub trait DiscriminantKind {
+            #[lang = "discriminant_type"]
+            type Discriminant;
+        }
+
+        extern "rust-intrinsic" {
+            pub fn discriminant_value<T>(v: &T) -> <T as DiscriminantKind>::Discriminant;
+        }
+    }
+}
+
 #[lang = "sized"]
 trait Sized {}