]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
gccrs: add support for lang_item eq and PartialEq trait
authorPhilip Herron <herron.philip@googlemail.com>
Thu, 19 Dec 2024 16:43:49 +0000 (16:43 +0000)
committerArthur Cohen <arthur.cohen@embecosm.com>
Fri, 21 Mar 2025 11:56:55 +0000 (12:56 +0100)
The Eq and Partial Ord are very similar to the operator overloads
we support for add/sub/etc... but they differ in that usually the
function call name matches the name of the lang item. This time
we need to have support to send in a new path for the method call
on the lang item we want instead of just the name of the lang item.

NOTE: this test case doesnt work correctly yet we need to support
the derive of partial eq on enums to generate the correct comparison
code for that.

Fixes Rust-GCC#3302

gcc/rust/ChangeLog:

* backend/rust-compile-expr.cc (CompileExpr::visit): handle partial_eq possible call
* backend/rust-compile-expr.h: handle case where lang item calls differ from name
* hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): new helper
* hir/tree/rust-hir-expr.h: likewise
* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): handle partial_eq
(TypeCheckExpr::resolve_operator_overload): likewise
* typecheck/rust-hir-type-check-expr.h: likewise
* util/rust-lang-item.cc (LangItem::ComparisonToLangItem): map comparison to lang item
(LangItem::ComparisonToSegment): likewise
* util/rust-lang-item.h: new lang items PartialOrd and Eq
* util/rust-operators.h (enum class): likewise

gcc/testsuite/ChangeLog:

* rust/compile/nr2/exclude: nr2 cant handle this
* rust/compile/cmp1.rs: New test.

Signed-off-by: Philip Herron <herron.philip@googlemail.com>
gcc/rust/backend/rust-compile-expr.cc
gcc/rust/backend/rust-compile-expr.h
gcc/rust/hir/tree/rust-hir-expr.cc
gcc/rust/hir/tree/rust-hir-expr.h
gcc/rust/typecheck/rust-hir-type-check-expr.cc
gcc/rust/typecheck/rust-hir-type-check-expr.h
gcc/rust/util/rust-lang-item.cc
gcc/rust/util/rust-lang-item.h
gcc/rust/util/rust-operators.h
gcc/testsuite/rust/compile/cmp1.rs [new file with mode: 0644]
gcc/testsuite/rust/compile/nr2/exclude

index e0fb1da3feb393324471a2bc31abdb2d85485c32..900e080ea0efd1dcb4bf9aa62e90e9efd8165863 100644 (file)
@@ -279,6 +279,26 @@ CompileExpr::visit (HIR::ComparisonExpr &expr)
   auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
   auto location = expr.get_locus ();
 
+  // this might be an operator overload situation lets check
+  TyTy::FnType *fntype;
+  bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
+    expr.get_mappings ().get_hirid (), &fntype);
+  if (is_op_overload)
+    {
+      auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
+      auto segment = HIR::PathIdentSegment (seg_name);
+      auto lang_item_type
+       = LangItem::ComparisonToLangItem (expr.get_expr_type ());
+
+      rhs = address_expression (rhs, EXPR_LOCATION (rhs));
+
+      translated = resolve_operator_overload (
+       lang_item_type, expr, lhs, rhs, expr.get_lhs (),
+       tl::optional<std::reference_wrapper<HIR::Expr>> (expr.get_rhs ()),
+       segment);
+      return;
+    }
+
   translated = Backend::comparison_expression (op, lhs, rhs, location);
 }
 
@@ -1478,7 +1498,8 @@ CompileExpr::get_receiver_from_dyn (const TyTy::DynamicObjectType *dyn,
 tree
 CompileExpr::resolve_operator_overload (
   LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs,
-  HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr)
+  HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
+  HIR::PathIdentSegment specified_segment)
 {
   TyTy::FnType *fntype;
   bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
@@ -1499,7 +1520,10 @@ CompileExpr::resolve_operator_overload (
     }
 
   // lookup compiled functions since it may have already been compiled
-  HIR::PathIdentSegment segment_name (LangItem::ToString (lang_item_type));
+  HIR::PathIdentSegment segment_name
+    = specified_segment.is_error ()
+       ? HIR::PathIdentSegment (LangItem::ToString (lang_item_type))
+       : specified_segment;
   tree fn_expr = resolve_method_address (fntype, receiver, expr.get_locus ());
 
   // lookup the autoderef mappings
index b8c4220ded7e0a7cb5520a01d48a58ceecfe43a7..dc78dee3c8f60803e5551c79cb64f28a645a65fe 100644 (file)
@@ -99,7 +99,9 @@ protected:
   tree resolve_operator_overload (
     LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs,
     tree rhs, HIR::Expr &lhs_expr,
-    tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr);
+    tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
+    HIR::PathIdentSegment specified_segment
+    = HIR::PathIdentSegment::create_error ());
 
   tree compile_bool_literal (const HIR::LiteralExpr &expr,
                             const TyTy::BaseType *tyty);
index 4a902c6559471f4f98115811269ce38141b067f3..2ded789e60b15c49066a3850f088f1e0a73ce1e3 100644 (file)
@@ -1298,6 +1298,12 @@ OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
     locus (expr.get_locus ())
 {}
 
+OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
+  : node_mappings (expr.get_mappings ()),
+    lvalue_mappings (expr.get_expr ().get_mappings ()),
+    locus (expr.get_locus ())
+{}
+
 AnonConst::AnonConst (NodeId id, std::unique_ptr<Expr> expr)
   : id (id), expr (std::move (expr))
 {
index 1273466b74cd933c77cda5cee70e06b1c044ac19..f8f212884b67f60ea44d570c5f212bda729eaaeb 100644 (file)
@@ -2816,6 +2816,8 @@ public:
 
   OperatorExprMeta (HIR::ArrayIndexExpr &expr);
 
+  OperatorExprMeta (HIR::ComparisonExpr &expr);
+
   const Analysis::NodeMapping &get_mappings () const { return node_mappings; }
 
   const Analysis::NodeMapping &get_lvalue_mappings () const
index 2ea8b4127e69626a43529f7e50d0f51b44b0a36f..7899b1a7943a2c7bca4e653ba3c5fdc83f1c135c 100644 (file)
@@ -344,6 +344,21 @@ TypeCheckExpr::visit (HIR::ComparisonExpr &expr)
   auto lhs = TypeCheckExpr::Resolve (expr.get_lhs ());
   auto rhs = TypeCheckExpr::Resolve (expr.get_rhs ());
 
+  auto borrowed_rhs
+    = new TyTy::ReferenceType (mappings.get_next_hir_id (),
+                              TyTy::TyVar (rhs->get_ref ()), Mutability::Imm);
+  context->insert_implicit_type (borrowed_rhs->get_ref (), borrowed_rhs);
+
+  auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
+  auto segment = HIR::PathIdentSegment (seg_name);
+  auto lang_item_type = LangItem::ComparisonToLangItem (expr.get_expr_type ());
+
+  bool operator_overloaded
+    = resolve_operator_overload (lang_item_type, expr, lhs, borrowed_rhs,
+                                segment);
+  if (operator_overloaded)
+    return;
+
   unify_site (expr.get_mappings ().get_hirid (),
              TyTy::TyWithLocation (lhs, expr.get_lhs ().get_locus ()),
              TyTy::TyWithLocation (rhs, expr.get_rhs ().get_locus ()),
@@ -1640,10 +1655,10 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
 }
 
 bool
-TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
-                                         HIR::OperatorExprMeta expr,
-                                         TyTy::BaseType *lhs,
-                                         TyTy::BaseType *rhs)
+TypeCheckExpr::resolve_operator_overload (
+  LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
+  TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+  HIR::PathIdentSegment specified_segment)
 {
   // look up lang item for arithmetic type
   std::string associated_item_name = LangItem::ToString (lang_item_type);
@@ -1661,7 +1676,9 @@ TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
       current_context = context->peek_context ();
     }
 
-  auto segment = HIR::PathIdentSegment (associated_item_name);
+  auto segment = specified_segment.is_error ()
+                  ? HIR::PathIdentSegment (associated_item_name)
+                  : specified_segment;
   auto candidates = MethodResolver::Probe (lhs, segment);
 
   // remove any recursive candidates
index 82f421e326afee0a0883dabd5bfb47c1a4d26321..2a0022ce7018f55f278c7ca469482a8683cca6a5 100644 (file)
@@ -97,7 +97,9 @@ public:
 protected:
   bool resolve_operator_overload (LangItem::Kind lang_item_type,
                                  HIR::OperatorExprMeta expr,
-                                 TyTy::BaseType *lhs, TyTy::BaseType *rhs);
+                                 TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+                                 HIR::PathIdentSegment specified_segment
+                                 = HIR::PathIdentSegment::create_error ());
 
   bool resolve_fn_trait_call (HIR::CallExpr &expr,
                              TyTy::BaseType *function_tyty,
index c4c1d1c093ab2eed76183680aa28bdee84cd280d..4a609096144358ef435dd1bf639d245be68c7a15 100644 (file)
@@ -98,6 +98,9 @@ const BiMap<std::string, LangItem::Kind> Rust::LangItem::lang_items = {{
 
   {"into_iter", Kind::INTOITER_INTOITER},
   {"next", Kind::ITERATOR_NEXT},
+
+  {"eq", Kind::EQ},
+  {"partial_ord", Kind::PARTIAL_ORD},
 }};
 
 tl::optional<LangItem::Kind>
@@ -145,6 +148,47 @@ LangItem::OperatorToLangItem (ArithmeticOrLogicalOperator op)
   rust_unreachable ();
 }
 
+LangItem::Kind
+LangItem::ComparisonToLangItem (ComparisonOperator op)
+{
+  switch (op)
+    {
+    case ComparisonOperator::NOT_EQUAL:
+    case ComparisonOperator::EQUAL:
+      return LangItem::Kind::EQ;
+
+    case ComparisonOperator::GREATER_THAN:
+    case ComparisonOperator::LESS_THAN:
+    case ComparisonOperator::GREATER_OR_EQUAL:
+    case ComparisonOperator::LESS_OR_EQUAL:
+      return LangItem::Kind::PARTIAL_ORD;
+    }
+
+  rust_unreachable ();
+}
+
+std::string
+LangItem::ComparisonToSegment (ComparisonOperator op)
+{
+  switch (op)
+    {
+    case ComparisonOperator::NOT_EQUAL:
+      return "ne";
+    case ComparisonOperator::EQUAL:
+      return "eq";
+    case ComparisonOperator::GREATER_THAN:
+      return "gt";
+    case ComparisonOperator::LESS_THAN:
+      return "lt";
+    case ComparisonOperator::GREATER_OR_EQUAL:
+      return "ge";
+    case ComparisonOperator::LESS_OR_EQUAL:
+      return "le";
+    }
+
+  rust_unreachable ();
+}
+
 LangItem::Kind
 LangItem::CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op)
 {
index 9e432e2ccc6e7ca03a276abf2f9c777bd3b009bb..62b15d7b3fc54d513b1275c70184d8f471041ae3 100644 (file)
@@ -45,6 +45,8 @@ public:
 
     NEGATION,
     NOT,
+    EQ,
+    PARTIAL_ORD,
 
     ADD_ASSIGN,
     SUB_ASSIGN,
@@ -136,6 +138,9 @@ public:
   static Kind
   CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op);
   static Kind NegationOperatorToLangItem (NegationOperator op);
+  static Kind ComparisonToLangItem (ComparisonOperator op);
+
+  static std::string ComparisonToSegment (ComparisonOperator op);
 };
 
 } // namespace Rust
index 608e771828af1783309d0ebb97b310367e9c16b3..02b1820b28cdf3e18752dede8ee4fbe9871810a7 100644 (file)
@@ -43,10 +43,10 @@ enum class ComparisonOperator
 {
   EQUAL,           // std::cmp::PartialEq::eq
   NOT_EQUAL,       // std::cmp::PartialEq::ne
-  GREATER_THAN,            // std::cmp::PartialEq::gt
-  LESS_THAN,       // std::cmp::PartialEq::lt
-  GREATER_OR_EQUAL, // std::cmp::PartialEq::ge
-  LESS_OR_EQUAL            // std::cmp::PartialEq::le
+  GREATER_THAN,            // std::cmp::PartialOrd::gt
+  LESS_THAN,       // std::cmp::PartialOrd::lt
+  GREATER_OR_EQUAL, // std::cmp::PartialOrd::ge
+  LESS_OR_EQUAL            // std::cmp::PartialOrd::le
 };
 
 enum class LazyBooleanOperator
diff --git a/gcc/testsuite/rust/compile/cmp1.rs b/gcc/testsuite/rust/compile/cmp1.rs
new file mode 100644 (file)
index 0000000..4da5b1c
--- /dev/null
@@ -0,0 +1,78 @@
+// { dg-options "-w" }
+// taken from https://github.com/rust-lang/rust/blob/e1884a8e3c3e813aada8254edfa120e85bf5ffca/library/core/src/cmp.rs#L98
+
+#[lang = "sized"]
+pub trait Sized {}
+
+#[lang = "eq"]
+#[stable(feature = "rust1", since = "1.0.0")]
+#[doc(alias = "==")]
+#[doc(alias = "!=")]
+pub trait PartialEq<Rhs: ?Sized = Self> {
+    /// This method tests for `self` and `other` values to be equal, and is used
+    /// by `==`.
+    #[must_use]
+    #[stable(feature = "rust1", since = "1.0.0")]
+    fn eq(&self, other: &Rhs) -> bool;
+
+    /// This method tests for `!=`.
+    #[inline]
+    #[must_use]
+    #[stable(feature = "rust1", since = "1.0.0")]
+    fn ne(&self, other: &Rhs) -> bool {
+        !self.eq(other)
+    }
+}
+
+enum BookFormat {
+    Paperback,
+    Hardback,
+    Ebook,
+}
+
+impl PartialEq<BookFormat> for BookFormat {
+    fn eq(&self, other: &BookFormat) -> bool {
+        self == other
+    }
+}
+
+pub struct Book {
+    isbn: i32,
+    format: BookFormat,
+}
+
+// Implement <Book> == <BookFormat> comparisons
+impl PartialEq<BookFormat> for Book {
+    fn eq(&self, other: &BookFormat) -> bool {
+        self.format == *other
+    }
+}
+
+// Implement <BookFormat> == <Book> comparisons
+impl PartialEq<Book> for BookFormat {
+    fn eq(&self, other: &Book) -> bool {
+        *self == other.format
+    }
+}
+
+// Implement <Book> == <Book> comparisons
+impl PartialEq<Book> for Book {
+    fn eq(&self, other: &Book) -> bool {
+        self.isbn == other.isbn
+    }
+}
+
+pub fn main() {
+    let b1 = Book {
+        isbn: 1,
+        format: BookFormat::Paperback,
+    };
+    let b2 = Book {
+        isbn: 2,
+        format: BookFormat::Paperback,
+    };
+
+    let _c1: bool = b1 == BookFormat::Paperback;
+    let _c2: bool = BookFormat::Paperback == b2;
+    let _c3: bool = b1 != b2;
+}
index e7344ed0d5970007be8ea2b94f5ab5172afd8da2..af7d105debc3e4cb707246d58a479c5955ec5d0b 100644 (file)
@@ -196,4 +196,5 @@ additional-trait-bounds2.rs
 auto_traits2.rs
 auto_traits3.rs
 issue-3140.rs
+cmp1.rs
 # please don't delete the trailing newline