From: Arthur Cohen Date: Tue, 22 Apr 2025 20:21:01 +0000 (+0200) Subject: gccrs: derive(PartialEq): Implement proper discriminant comparison X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=7e48be1af720be6598dbf12c0abe6edbd9e92a47;p=thirdparty%2Fgcc.git gccrs: derive(PartialEq): Implement proper discriminant comparison And use the new EnumMatchBuilder class to do so. gcc/rust/ChangeLog: * expand/rust-derive-partial-eq.cc (DerivePartialEq::eq_fn): Change signature. (DerivePartialEq::visit_tuple): Use new eq_fn API. (DerivePartialEq::visit_struct): Likewise. (DerivePartialEq::visit_enum): Implement proper discriminant comparison. * expand/rust-derive-partial-eq.h: Change eq_fn signature. gcc/testsuite/ChangeLog: * rust/execute/torture/derive-partialeq2.rs: Add declaration for discriminant_value. --- diff --git a/gcc/rust/expand/rust-derive-partial-eq.cc b/gcc/rust/expand/rust-derive-partial-eq.cc index 1f2fa35b284..ea6a995c4f1 100644 --- a/gcc/rust/expand/rust-derive-partial-eq.cc +++ b/gcc/rust/expand/rust-derive-partial-eq.cc @@ -64,11 +64,9 @@ DerivePartialEq::partialeq_impls ( } std::unique_ptr -DerivePartialEq::eq_fn (std::unique_ptr &&cmp_expression, +DerivePartialEq::eq_fn (std::unique_ptr &&block, std::string type_name) { - auto block = builder.block (tl::nullopt, std::move (cmp_expression)); - auto self_type = std::unique_ptr (new TypePath (builder.type_path ("Self"))); @@ -118,7 +116,8 @@ DerivePartialEq::visit_tuple (TupleStruct &item) auto type_name = item.get_struct_name ().as_string (); auto fields = SelfOther::indexes (builder, item.get_fields ()); - auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name); + auto fn = eq_fn (builder.block (build_eq_expression (std::move (fields))), + type_name); expanded = partialeq_impls (std::move (fn), type_name, item.get_generic_params ()); @@ -130,7 +129,8 @@ DerivePartialEq::visit_struct (StructStruct &item) auto type_name = item.get_struct_name ().as_string (); auto fields = SelfOther::fields (builder, item.get_fields ()); - auto fn = eq_fn (build_eq_expression (std::move (fields)), type_name); + auto fn = eq_fn (builder.block (build_eq_expression (std::move (fields))), + type_name); expanded = partialeq_impls (std::move (fn), type_name, item.get_generic_params ()); @@ -270,46 +270,58 @@ DerivePartialEq::visit_enum (Enum &item) auto cases = std::vector (); auto type_name = item.get_identifier ().as_string (); + auto eq_expr_fn = [this] (std::vector &&fields) { + return build_eq_expression (std::move (fields)); + }; + + auto let_sd + = builder.discriminant_value (DerivePartialEq::self_discr, "self"); + auto let_od + = builder.discriminant_value (DerivePartialEq::other_discr, "other"); + + auto discr_cmp + = builder.comparison_expr (builder.identifier (DerivePartialEq::self_discr), + builder.identifier ( + DerivePartialEq::other_discr), + ComparisonOperator::EQUAL); + for (auto &variant : item.get_variants ()) { auto variant_path = builder.variant_path (type_name, variant->get_identifier ().as_string ()); + auto enum_builder = EnumMatchBuilder (variant_path, eq_expr_fn, builder); + switch (variant->get_enum_item_kind ()) { - case EnumItem::Kind::Identifier: - case EnumItem::Kind::Discriminant: - cases.emplace_back (match_enum_identifier (variant_path, variant)); - break; case EnumItem::Kind::Tuple: - cases.emplace_back ( - match_enum_tuple (variant_path, - static_cast (*variant))); + cases.emplace_back (enum_builder.tuple (*variant)); break; case EnumItem::Kind::Struct: - cases.emplace_back ( - match_enum_struct (variant_path, - static_cast (*variant))); + cases.emplace_back (enum_builder.strukt (*variant)); + break; + case EnumItem::Kind::Identifier: + case EnumItem::Kind::Discriminant: + // We don't need to do anything for these, as they are handled by the + // discriminant value comparison break; } } - // NOTE: Mention using discriminant_value and skipping that last case, and - // instead skipping all identifiers/discriminant enum items and returning - // `true` in the wildcard case - // In case the two instances of `Self` don't have the same discriminant, // automatically return false. cases.emplace_back ( - builder.match_case (builder.wildcard (), builder.literal_bool (false))); + builder.match_case (builder.wildcard (), std::move (discr_cmp))); auto match = builder.match (builder.tuple (vec (builder.identifier ("self"), builder.identifier ("other"))), std::move (cases)); - auto fn = eq_fn (std::move (match), type_name); + auto fn = eq_fn (builder.block (vec (std::move (let_sd), std::move (let_od)), + std::move (match)), + type_name); expanded = partialeq_impls (std::move (fn), type_name, item.get_generic_params ()); diff --git a/gcc/rust/expand/rust-derive-partial-eq.h b/gcc/rust/expand/rust-derive-partial-eq.h index fdfe4dacb85..7985414c252 100644 --- a/gcc/rust/expand/rust-derive-partial-eq.h +++ b/gcc/rust/expand/rust-derive-partial-eq.h @@ -44,7 +44,7 @@ private: std::unique_ptr &&eq_fn, std::string name, const std::vector> &type_generics); - std::unique_ptr eq_fn (std::unique_ptr &&cmp_expression, + std::unique_ptr eq_fn (std::unique_ptr &&block, std::string type_name); /** @@ -61,6 +61,9 @@ private: MatchCase match_enum_struct (PathInExpression variant_path, const EnumItemStruct &variant); + constexpr static const char *self_discr = "#self_discr"; + constexpr static const char *other_discr = "#other_discr"; + virtual void visit_struct (StructStruct &item) override; virtual void visit_tuple (TupleStruct &item) override; virtual void visit_enum (Enum &item) override; diff --git a/gcc/testsuite/rust/execute/torture/derive-partialeq2.rs b/gcc/testsuite/rust/execute/torture/derive-partialeq2.rs index 70ed7dcd93d..e316017753a 100644 --- a/gcc/testsuite/rust/execute/torture/derive-partialeq2.rs +++ b/gcc/testsuite/rust/execute/torture/derive-partialeq2.rs @@ -2,6 +2,20 @@ #![feature(intrinsics)] +pub mod core { + pub mod intrinsics { + #[lang = "discriminant_kind"] + pub trait DiscriminantKind { + #[lang = "discriminant_type"] + type Discriminant; + } + + extern "rust-intrinsic" { + pub fn discriminant_value(v: &T) -> ::Discriminant; + } + } +} + #[lang = "sized"] trait Sized {}