]> git.ipfire.org Git - thirdparty/gcc.git/blob - gcc/rust/typecheck/rust-tyty-cmp.h
Update copyright years.
[thirdparty/gcc.git] / gcc / rust / typecheck / rust-tyty-cmp.h
1 // Copyright (C) 2020-2024 Free Software Foundation, Inc.
2
3 // This file is part of GCC.
4
5 // GCC is free software; you can redistribute it and/or modify it under
6 // the terms of the GNU General Public License as published by the Free
7 // Software Foundation; either version 3, or (at your option) any later
8 // version.
9
10 // GCC is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 // FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 // for more details.
14
15 // You should have received a copy of the GNU General Public License
16 // along with GCC; see the file COPYING3. If not see
17 // <http://www.gnu.org/licenses/>.
18
19 #ifndef RUST_TYTY_CMP_H
20 #define RUST_TYTY_CMP_H
21
22 #include "rust-diagnostics.h"
23 #include "rust-tyty.h"
24 #include "rust-tyty-visitor.h"
25 #include "rust-hir-map.h"
26 #include "rust-hir-type-check.h"
27
28 namespace Rust {
29 namespace TyTy {
30
// FIXME: this global flag is a stopgap. It should be replaced by a proper
// implementation of the matching used when assembling method candidates.
33 extern bool autoderef_cmp_flag;
34
35 class BaseCmp : public TyConstVisitor
36 {
37 public:
38 virtual bool can_eq (const BaseType *other)
39 {
40 if (other->get_kind () == TypeKind::PARAM)
41 {
42 const ParamType *p = static_cast<const ParamType *> (other);
43 other = p->resolve ();
44 }
45 if (other->get_kind () == TypeKind::PLACEHOLDER)
46 {
47 const PlaceholderType *p = static_cast<const PlaceholderType *> (other);
48 if (p->can_resolve ())
49 {
50 other = p->resolve ();
51 }
52 }
53 if (other->get_kind () == TypeKind::PROJECTION)
54 {
55 const ProjectionType *p = static_cast<const ProjectionType *> (other);
56 other = p->get ();
57 }
58
59 other->accept_vis (*this);
60 return ok;
61 }
62
63 virtual void visit (const TupleType &type) override
64 {
65 ok = false;
66
67 if (emit_error_flag)
68 {
69 Location ref_locus = mappings->lookup_location (type.get_ref ());
70 Location base_locus
71 = mappings->lookup_location (get_base ()->get_ref ());
72 RichLocation r (ref_locus);
73 r.add_range (base_locus);
74 rust_error_at (r, "expected [%s] got [%s]",
75 get_base ()->as_string ().c_str (),
76 type.as_string ().c_str ());
77 }
78 }
79
80 virtual void visit (const ADTType &type) override
81 {
82 ok = false;
83 if (emit_error_flag)
84 {
85 Location ref_locus = mappings->lookup_location (type.get_ref ());
86 Location base_locus
87 = mappings->lookup_location (get_base ()->get_ref ());
88 RichLocation r (ref_locus);
89 r.add_range (base_locus);
90 rust_error_at (r, "expected [%s] got [%s]",
91 get_base ()->as_string ().c_str (),
92 type.as_string ().c_str ());
93 }
94 }
95
96 virtual void visit (const InferType &type) override
97 {
98 ok = false;
99 if (emit_error_flag)
100 {
101 Location ref_locus = mappings->lookup_location (type.get_ref ());
102 Location base_locus
103 = mappings->lookup_location (get_base ()->get_ref ());
104 RichLocation r (ref_locus);
105 r.add_range (base_locus);
106 rust_error_at (r, "expected [%s] got [%s]",
107 get_base ()->as_string ().c_str (),
108 type.as_string ().c_str ());
109 }
110 }
111
112 virtual void visit (const FnType &type) override
113 {
114 ok = false;
115 if (emit_error_flag)
116 {
117 Location ref_locus = mappings->lookup_location (type.get_ref ());
118 Location base_locus
119 = mappings->lookup_location (get_base ()->get_ref ());
120 RichLocation r (ref_locus);
121 r.add_range (base_locus);
122 rust_error_at (r, "expected [%s] got [%s]",
123 get_base ()->as_string ().c_str (),
124 type.as_string ().c_str ());
125 }
126 }
127
128 virtual void visit (const FnPtr &type) override
129 {
130 ok = false;
131 if (emit_error_flag)
132 {
133 Location ref_locus = mappings->lookup_location (type.get_ref ());
134 Location base_locus
135 = mappings->lookup_location (get_base ()->get_ref ());
136 RichLocation r (ref_locus);
137 r.add_range (base_locus);
138 rust_error_at (r, "expected [%s] got [%s]",
139 get_base ()->as_string ().c_str (),
140 type.as_string ().c_str ());
141 }
142 }
143
144 virtual void visit (const ArrayType &type) override
145 {
146 ok = false;
147 if (emit_error_flag)
148 {
149 Location ref_locus = mappings->lookup_location (type.get_ref ());
150 Location base_locus
151 = mappings->lookup_location (get_base ()->get_ref ());
152 RichLocation r (ref_locus);
153 r.add_range (base_locus);
154 rust_error_at (r, "expected [%s] got [%s]",
155 get_base ()->as_string ().c_str (),
156 type.as_string ().c_str ());
157 }
158 }
159
160 virtual void visit (const SliceType &type) override
161 {
162 ok = false;
163 if (emit_error_flag)
164 {
165 Location ref_locus = mappings->lookup_location (type.get_ref ());
166 Location base_locus
167 = mappings->lookup_location (get_base ()->get_ref ());
168 RichLocation r (ref_locus);
169 r.add_range (base_locus);
170 rust_error_at (r, "expected [%s] got [%s]",
171 get_base ()->as_string ().c_str (),
172 type.as_string ().c_str ());
173 }
174 }
175
176 virtual void visit (const BoolType &type) override
177 {
178 ok = false;
179 if (emit_error_flag)
180 {
181 Location ref_locus = mappings->lookup_location (type.get_ref ());
182 Location base_locus
183 = mappings->lookup_location (get_base ()->get_ref ());
184 RichLocation r (ref_locus);
185 r.add_range (base_locus);
186 rust_error_at (r, "expected [%s] got [%s]",
187 get_base ()->as_string ().c_str (),
188 type.as_string ().c_str ());
189 }
190 }
191
192 virtual void visit (const IntType &type) override
193 {
194 ok = false;
195 if (emit_error_flag)
196 {
197 Location ref_locus = mappings->lookup_location (type.get_ref ());
198 Location base_locus
199 = mappings->lookup_location (get_base ()->get_ref ());
200 RichLocation r (ref_locus);
201 r.add_range (base_locus);
202 rust_error_at (r, "expected [%s] got [%s]",
203 get_base ()->as_string ().c_str (),
204 type.as_string ().c_str ());
205 }
206 }
207
208 virtual void visit (const UintType &type) override
209 {
210 ok = false;
211 if (emit_error_flag)
212 {
213 Location ref_locus = mappings->lookup_location (type.get_ref ());
214 Location base_locus
215 = mappings->lookup_location (get_base ()->get_ref ());
216 RichLocation r (ref_locus);
217 r.add_range (base_locus);
218 rust_error_at (r, "expected [%s] got [%s]",
219 get_base ()->as_string ().c_str (),
220 type.as_string ().c_str ());
221 }
222 }
223
224 virtual void visit (const USizeType &type) override
225 {
226 ok = false;
227 if (emit_error_flag)
228 {
229 Location ref_locus = mappings->lookup_location (type.get_ref ());
230 Location base_locus
231 = mappings->lookup_location (get_base ()->get_ref ());
232 RichLocation r (ref_locus);
233 r.add_range (base_locus);
234 rust_error_at (r, "expected [%s] got [%s]",
235 get_base ()->as_string ().c_str (),
236 type.as_string ().c_str ());
237 }
238 }
239
240 virtual void visit (const ISizeType &type) override
241 {
242 ok = false;
243 if (emit_error_flag)
244 {
245 Location ref_locus = mappings->lookup_location (type.get_ref ());
246 Location base_locus
247 = mappings->lookup_location (get_base ()->get_ref ());
248 RichLocation r (ref_locus);
249 r.add_range (base_locus);
250 rust_error_at (r, "expected [%s] got [%s]",
251 get_base ()->as_string ().c_str (),
252 type.as_string ().c_str ());
253 }
254 }
255
256 virtual void visit (const FloatType &type) override
257 {
258 ok = false;
259 if (emit_error_flag)
260 {
261 Location ref_locus = mappings->lookup_location (type.get_ref ());
262 Location base_locus
263 = mappings->lookup_location (get_base ()->get_ref ());
264 RichLocation r (ref_locus);
265 r.add_range (base_locus);
266 rust_error_at (r, "expected [%s] got [%s]",
267 get_base ()->as_string ().c_str (),
268 type.as_string ().c_str ());
269 }
270 }
271
272 virtual void visit (const ErrorType &type) override
273 {
274 ok = false;
275 if (emit_error_flag)
276 {
277 Location ref_locus = mappings->lookup_location (type.get_ref ());
278 Location base_locus
279 = mappings->lookup_location (get_base ()->get_ref ());
280 RichLocation r (ref_locus);
281 r.add_range (base_locus);
282 rust_error_at (r, "expected [%s] got [%s]",
283 get_base ()->as_string ().c_str (),
284 type.as_string ().c_str ());
285 }
286 }
287
288 virtual void visit (const CharType &type) override
289 {
290 ok = false;
291 if (emit_error_flag)
292 {
293 Location ref_locus = mappings->lookup_location (type.get_ref ());
294 Location base_locus
295 = mappings->lookup_location (get_base ()->get_ref ());
296 RichLocation r (ref_locus);
297 r.add_range (base_locus);
298 rust_error_at (r, "expected [%s] got [%s]",
299 get_base ()->as_string ().c_str (),
300 type.as_string ().c_str ());
301 }
302 }
303
304 virtual void visit (const ReferenceType &type) override
305 {
306 ok = false;
307 if (emit_error_flag)
308 {
309 Location ref_locus = mappings->lookup_location (type.get_ref ());
310 Location base_locus
311 = mappings->lookup_location (get_base ()->get_ref ());
312 RichLocation r (ref_locus);
313 r.add_range (base_locus);
314 rust_error_at (r, "expected [%s] got [%s]",
315 get_base ()->as_string ().c_str (),
316 type.as_string ().c_str ());
317 }
318 }
319
320 virtual void visit (const PointerType &type) override
321 {
322 ok = false;
323 if (emit_error_flag)
324 {
325 Location ref_locus = mappings->lookup_location (type.get_ref ());
326 Location base_locus
327 = mappings->lookup_location (get_base ()->get_ref ());
328 RichLocation r (ref_locus);
329 r.add_range (base_locus);
330 rust_error_at (r, "expected [%s] got [%s]",
331 get_base ()->as_string ().c_str (),
332 type.as_string ().c_str ());
333 }
334 }
335
336 virtual void visit (const StrType &type) override
337 {
338 ok = false;
339 if (emit_error_flag)
340 {
341 Location ref_locus = mappings->lookup_location (type.get_ref ());
342 Location base_locus
343 = mappings->lookup_location (get_base ()->get_ref ());
344 RichLocation r (ref_locus);
345 r.add_range (base_locus);
346 rust_error_at (r, "expected [%s] got [%s]",
347 get_base ()->as_string ().c_str (),
348 type.as_string ().c_str ());
349 }
350 }
351
352 virtual void visit (const NeverType &type) override
353 {
354 ok = false;
355 if (emit_error_flag)
356 {
357 Location ref_locus = mappings->lookup_location (type.get_ref ());
358 Location base_locus
359 = mappings->lookup_location (get_base ()->get_ref ());
360 RichLocation r (ref_locus);
361 r.add_range (base_locus);
362 rust_error_at (r, "expected [%s] got [%s]",
363 get_base ()->as_string ().c_str (),
364 type.as_string ().c_str ());
365 }
366 }
367
368 virtual void visit (const ProjectionType &type) override
369 {
370 ok = false;
371 if (emit_error_flag)
372 {
373 Location ref_locus = mappings->lookup_location (type.get_ref ());
374 Location base_locus
375 = mappings->lookup_location (get_base ()->get_ref ());
376 RichLocation r (ref_locus);
377 r.add_range (base_locus);
378 rust_error_at (r, "expected [%s] got [%s]",
379 get_base ()->as_string ().c_str (),
380 type.as_string ().c_str ());
381 }
382 }
383
384 virtual void visit (const PlaceholderType &type) override
385 {
386 // it is ok for types to can eq to a placeholder
387 ok = true;
388 }
389
390 virtual void visit (const ParamType &type) override
391 {
392 ok = false;
393 if (emit_error_flag)
394 {
395 Location ref_locus = mappings->lookup_location (type.get_ref ());
396 Location base_locus
397 = mappings->lookup_location (get_base ()->get_ref ());
398 RichLocation r (ref_locus);
399 r.add_range (base_locus);
400 rust_error_at (r, "expected [%s] got [%s]",
401 get_base ()->as_string ().c_str (),
402 type.as_string ().c_str ());
403 }
404 }
405
406 virtual void visit (const DynamicObjectType &type) override
407 {
408 ok = false;
409 if (emit_error_flag)
410 {
411 Location ref_locus = mappings->lookup_location (type.get_ref ());
412 Location base_locus
413 = mappings->lookup_location (get_base ()->get_ref ());
414 RichLocation r (ref_locus);
415 r.add_range (base_locus);
416 rust_error_at (r, "expected [%s] got [%s]",
417 get_base ()->as_string ().c_str (),
418 type.as_string ().c_str ());
419 }
420 }
421
422 virtual void visit (const ClosureType &type) override
423 {
424 ok = false;
425 if (emit_error_flag)
426 {
427 Location ref_locus = mappings->lookup_location (type.get_ref ());
428 Location base_locus
429 = mappings->lookup_location (get_base ()->get_ref ());
430 RichLocation r (ref_locus);
431 r.add_range (base_locus);
432 rust_error_at (r, "expected [%s] got [%s]",
433 get_base ()->as_string ().c_str (),
434 type.as_string ().c_str ());
435 }
436 }
437
438 protected:
439 BaseCmp (const BaseType *base, bool emit_errors)
440 : mappings (Analysis::Mappings::get ()),
441 context (Resolver::TypeCheckContext::get ()), ok (false),
442 emit_error_flag (emit_errors)
443 {}
444
445 Analysis::Mappings *mappings;
446 Resolver::TypeCheckContext *context;
447
448 bool ok;
449 bool emit_error_flag;
450
451 private:
452 /* Returns a pointer to the ty that created this rule. */
453 virtual const BaseType *get_base () const = 0;
454 };
455
456 class InferCmp : public BaseCmp
457 {
458 using Rust::TyTy::BaseCmp::visit;
459
460 public:
461 InferCmp (const InferType *base, bool emit_errors)
462 : BaseCmp (base, emit_errors), base (base)
463 {}
464
465 void visit (const BoolType &type) override
466 {
467 bool is_valid
468 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
469 if (is_valid)
470 {
471 ok = true;
472 return;
473 }
474
475 BaseCmp::visit (type);
476 }
477
478 void visit (const IntType &type) override
479 {
480 bool is_valid
481 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
482 || (base->get_infer_kind ()
483 == TyTy::InferType::InferTypeKind::INTEGRAL);
484 if (is_valid)
485 {
486 ok = true;
487 return;
488 }
489
490 BaseCmp::visit (type);
491 }
492
493 void visit (const UintType &type) override
494 {
495 bool is_valid
496 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
497 || (base->get_infer_kind ()
498 == TyTy::InferType::InferTypeKind::INTEGRAL);
499 if (is_valid)
500 {
501 ok = true;
502 return;
503 }
504
505 BaseCmp::visit (type);
506 }
507
508 void visit (const USizeType &type) override
509 {
510 bool is_valid
511 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
512 || (base->get_infer_kind ()
513 == TyTy::InferType::InferTypeKind::INTEGRAL);
514 if (is_valid)
515 {
516 ok = true;
517 return;
518 }
519
520 BaseCmp::visit (type);
521 }
522
523 void visit (const ISizeType &type) override
524 {
525 bool is_valid
526 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
527 || (base->get_infer_kind ()
528 == TyTy::InferType::InferTypeKind::INTEGRAL);
529 if (is_valid)
530 {
531 ok = true;
532 return;
533 }
534
535 BaseCmp::visit (type);
536 }
537
538 void visit (const FloatType &type) override
539 {
540 bool is_valid
541 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL)
542 || (base->get_infer_kind () == TyTy::InferType::InferTypeKind::FLOAT);
543 if (is_valid)
544 {
545 ok = true;
546 return;
547 }
548
549 BaseCmp::visit (type);
550 }
551
552 void visit (const ArrayType &type) override
553 {
554 bool is_valid
555 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
556 if (is_valid)
557 {
558 ok = true;
559 return;
560 }
561
562 BaseCmp::visit (type);
563 }
564
565 void visit (const SliceType &type) override
566 {
567 bool is_valid
568 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
569 if (is_valid)
570 {
571 ok = true;
572 return;
573 }
574
575 BaseCmp::visit (type);
576 }
577
578 void visit (const ADTType &type) override
579 {
580 bool is_valid
581 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
582 if (is_valid)
583 {
584 ok = true;
585 return;
586 }
587
588 BaseCmp::visit (type);
589 }
590
591 void visit (const TupleType &type) override
592 {
593 bool is_valid
594 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
595 if (is_valid)
596 {
597 ok = true;
598 return;
599 }
600
601 BaseCmp::visit (type);
602 }
603
604 void visit (const InferType &type) override
605 {
606 switch (base->get_infer_kind ())
607 {
608 case InferType::InferTypeKind::GENERAL:
609 ok = true;
610 return;
611
612 case InferType::InferTypeKind::INTEGRAL: {
613 if (type.get_infer_kind () == InferType::InferTypeKind::INTEGRAL)
614 {
615 ok = true;
616 return;
617 }
618 else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL)
619 {
620 ok = true;
621 return;
622 }
623 }
624 break;
625
626 case InferType::InferTypeKind::FLOAT: {
627 if (type.get_infer_kind () == InferType::InferTypeKind::FLOAT)
628 {
629 ok = true;
630 return;
631 }
632 else if (type.get_infer_kind () == InferType::InferTypeKind::GENERAL)
633 {
634 ok = true;
635 return;
636 }
637 }
638 break;
639 }
640
641 BaseCmp::visit (type);
642 }
643
644 void visit (const CharType &type) override
645 {
646 {
647 bool is_valid
648 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
649 if (is_valid)
650 {
651 ok = true;
652 return;
653 }
654
655 BaseCmp::visit (type);
656 }
657 }
658
659 void visit (const ReferenceType &type) override
660 {
661 bool is_valid
662 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
663 if (is_valid)
664 {
665 ok = true;
666 return;
667 }
668
669 BaseCmp::visit (type);
670 }
671
672 void visit (const PointerType &type) override
673 {
674 bool is_valid
675 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
676 if (is_valid)
677 {
678 ok = true;
679 return;
680 }
681
682 BaseCmp::visit (type);
683 }
684
685 void visit (const ParamType &) override { ok = true; }
686
687 void visit (const DynamicObjectType &type) override
688 {
689 bool is_valid
690 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
691 if (is_valid)
692 {
693 ok = true;
694 return;
695 }
696
697 BaseCmp::visit (type);
698 }
699
700 void visit (const ClosureType &type) override
701 {
702 bool is_valid
703 = (base->get_infer_kind () == TyTy::InferType::InferTypeKind::GENERAL);
704 if (is_valid)
705 {
706 ok = true;
707 return;
708 }
709
710 BaseCmp::visit (type);
711 }
712
713 private:
714 const BaseType *get_base () const override { return base; }
715 const InferType *base;
716 };
717
718 class FnCmp : public BaseCmp
719 {
720 using Rust::TyTy::BaseCmp::visit;
721
722 public:
723 FnCmp (const FnType *base, bool emit_errors)
724 : BaseCmp (base, emit_errors), base (base)
725 {}
726
727 void visit (const InferType &type) override
728 {
729 ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL;
730 }
731
732 void visit (const FnType &type) override
733 {
734 if (base->num_params () != type.num_params ())
735 {
736 BaseCmp::visit (type);
737 return;
738 }
739
740 for (size_t i = 0; i < base->num_params (); i++)
741 {
742 auto a = base->param_at (i).second;
743 auto b = type.param_at (i).second;
744
745 if (!a->can_eq (b, emit_error_flag))
746 {
747 emit_error_flag = false;
748 BaseCmp::visit (type);
749 return;
750 }
751 }
752
753 if (!base->get_return_type ()->can_eq (type.get_return_type (),
754 emit_error_flag))
755 {
756 emit_error_flag = false;
757 BaseCmp::visit (type);
758 return;
759 }
760
761 ok = true;
762 }
763
764 private:
765 const BaseType *get_base () const override { return base; }
766 const FnType *base;
767 };
768
769 class FnptrCmp : public BaseCmp
770 {
771 using Rust::TyTy::BaseCmp::visit;
772
773 public:
774 FnptrCmp (const FnPtr *base, bool emit_errors)
775 : BaseCmp (base, emit_errors), base (base)
776 {}
777
778 void visit (const InferType &type) override
779 {
780 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
781 {
782 BaseCmp::visit (type);
783 return;
784 }
785
786 ok = true;
787 }
788
789 void visit (const FnPtr &type) override
790 {
791 if (base->num_params () != type.num_params ())
792 {
793 BaseCmp::visit (type);
794 return;
795 }
796
797 auto this_ret_type = base->get_return_type ();
798 auto other_ret_type = type.get_return_type ();
799 if (!this_ret_type->can_eq (other_ret_type, emit_error_flag))
800 {
801 BaseCmp::visit (type);
802 return;
803 }
804
805 for (size_t i = 0; i < base->num_params (); i++)
806 {
807 auto this_param = base->param_at (i);
808 auto other_param = type.param_at (i);
809 if (!this_param->can_eq (other_param, emit_error_flag))
810 {
811 BaseCmp::visit (type);
812 return;
813 }
814 }
815
816 ok = true;
817 }
818
819 void visit (const FnType &type) override
820 {
821 if (base->num_params () != type.num_params ())
822 {
823 BaseCmp::visit (type);
824 return;
825 }
826
827 auto this_ret_type = base->get_return_type ();
828 auto other_ret_type = type.get_return_type ();
829 if (!this_ret_type->can_eq (other_ret_type, emit_error_flag))
830 {
831 BaseCmp::visit (type);
832 return;
833 }
834
835 for (size_t i = 0; i < base->num_params (); i++)
836 {
837 auto this_param = base->param_at (i);
838 auto other_param = type.param_at (i).second;
839 if (!this_param->can_eq (other_param, emit_error_flag))
840 {
841 BaseCmp::visit (type);
842 return;
843 }
844 }
845
846 ok = true;
847 }
848
849 private:
850 const BaseType *get_base () const override { return base; }
851 const FnPtr *base;
852 };
853
854 class ClosureCmp : public BaseCmp
855 {
856 using Rust::TyTy::BaseCmp::visit;
857
858 public:
859 ClosureCmp (const ClosureType *base, bool emit_errors)
860 : BaseCmp (base, emit_errors), base (base)
861 {}
862
863 void visit (const InferType &type) override
864 {
865 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
866 {
867 BaseCmp::visit (type);
868 return;
869 }
870
871 ok = true;
872 }
873
874 void visit (const ClosureType &type) override
875 {
876 if (base->get_def_id () != type.get_def_id ())
877 {
878 BaseCmp::visit (type);
879 return;
880 }
881
882 if (!base->get_parameters ().can_eq (&type.get_parameters (), false))
883 {
884 BaseCmp::visit (type);
885 return;
886 }
887
888 if (!base->get_result_type ().can_eq (&type.get_result_type (), false))
889 {
890 BaseCmp::visit (type);
891 return;
892 }
893
894 ok = true;
895 }
896
897 private:
898 const BaseType *get_base () const override { return base; }
899 const ClosureType *base;
900 };
901
902 class ArrayCmp : public BaseCmp
903 {
904 using Rust::TyTy::BaseCmp::visit;
905
906 public:
907 ArrayCmp (const ArrayType *base, bool emit_errors)
908 : BaseCmp (base, emit_errors), base (base)
909 {}
910
911 void visit (const ArrayType &type) override
912 {
913 // check base type
914 const BaseType *base_element = base->get_element_type ();
915 const BaseType *other_element = type.get_element_type ();
916 if (!base_element->can_eq (other_element, emit_error_flag))
917 {
918 BaseCmp::visit (type);
919 return;
920 }
921
922 ok = true;
923 }
924
925 void visit (const InferType &type) override
926 {
927 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
928 {
929 BaseCmp::visit (type);
930 return;
931 }
932
933 ok = true;
934 }
935
936 private:
937 const BaseType *get_base () const override { return base; }
938 const ArrayType *base;
939 };
940
941 class SliceCmp : public BaseCmp
942 {
943 using Rust::TyTy::BaseCmp::visit;
944
945 public:
946 SliceCmp (const SliceType *base, bool emit_errors)
947 : BaseCmp (base, emit_errors), base (base)
948 {}
949
950 void visit (const SliceType &type) override
951 {
952 // check base type
953 const BaseType *base_element = base->get_element_type ();
954 const BaseType *other_element = type.get_element_type ();
955 if (!base_element->can_eq (other_element, emit_error_flag))
956 {
957 BaseCmp::visit (type);
958 return;
959 }
960
961 ok = true;
962 }
963
964 void visit (const InferType &type) override
965 {
966 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
967 {
968 BaseCmp::visit (type);
969 return;
970 }
971
972 ok = true;
973 }
974
975 private:
976 const BaseType *get_base () const override { return base; }
977 const SliceType *base;
978 };
979
980 class BoolCmp : public BaseCmp
981 {
982 using Rust::TyTy::BaseCmp::visit;
983
984 public:
985 BoolCmp (const BoolType *base, bool emit_errors)
986 : BaseCmp (base, emit_errors), base (base)
987 {}
988
989 void visit (const BoolType &type) override { ok = true; }
990
991 void visit (const InferType &type) override
992 {
993 ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL;
994 }
995
996 private:
997 const BaseType *get_base () const override { return base; }
998 const BoolType *base;
999 };
1000
1001 class IntCmp : public BaseCmp
1002 {
1003 using Rust::TyTy::BaseCmp::visit;
1004
1005 public:
1006 IntCmp (const IntType *base, bool emit_errors)
1007 : BaseCmp (base, emit_errors), base (base)
1008 {}
1009
1010 void visit (const InferType &type) override
1011 {
1012 ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT;
1013 }
1014
1015 void visit (const IntType &type) override
1016 {
1017 ok = type.get_int_kind () == base->get_int_kind ();
1018 }
1019
1020 private:
1021 const BaseType *get_base () const override { return base; }
1022 const IntType *base;
1023 };
1024
1025 class UintCmp : public BaseCmp
1026 {
1027 using Rust::TyTy::BaseCmp::visit;
1028
1029 public:
1030 UintCmp (const UintType *base, bool emit_errors)
1031 : BaseCmp (base, emit_errors), base (base)
1032 {}
1033
1034 void visit (const InferType &type) override
1035 {
1036 ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT;
1037 }
1038
1039 void visit (const UintType &type) override
1040 {
1041 ok = type.get_uint_kind () == base->get_uint_kind ();
1042 }
1043
1044 private:
1045 const BaseType *get_base () const override { return base; }
1046 const UintType *base;
1047 };
1048
1049 class FloatCmp : public BaseCmp
1050 {
1051 using Rust::TyTy::BaseCmp::visit;
1052
1053 public:
1054 FloatCmp (const FloatType *base, bool emit_errors)
1055 : BaseCmp (base, emit_errors), base (base)
1056 {}
1057
1058 void visit (const InferType &type) override
1059 {
1060 ok = type.get_infer_kind () != InferType::InferTypeKind::INTEGRAL;
1061 }
1062
1063 void visit (const FloatType &type) override
1064 {
1065 ok = type.get_float_kind () == base->get_float_kind ();
1066 }
1067
1068 private:
1069 const BaseType *get_base () const override { return base; }
1070 const FloatType *base;
1071 };
1072
1073 class ADTCmp : public BaseCmp
1074 {
1075 using Rust::TyTy::BaseCmp::visit;
1076
1077 public:
1078 ADTCmp (const ADTType *base, bool emit_errors)
1079 : BaseCmp (base, emit_errors), base (base)
1080 {}
1081
1082 void visit (const ADTType &type) override
1083 {
1084 if (base->get_adt_kind () != type.get_adt_kind ())
1085 {
1086 BaseCmp::visit (type);
1087 return;
1088 }
1089
1090 if (base->get_identifier ().compare (type.get_identifier ()) != 0)
1091 {
1092 BaseCmp::visit (type);
1093 return;
1094 }
1095
1096 if (base->number_of_variants () != type.number_of_variants ())
1097 {
1098 BaseCmp::visit (type);
1099 return;
1100 }
1101
1102 for (size_t i = 0; i < type.number_of_variants (); ++i)
1103 {
1104 TyTy::VariantDef *a = base->get_variants ().at (i);
1105 TyTy::VariantDef *b = type.get_variants ().at (i);
1106
1107 if (a->num_fields () != b->num_fields ())
1108 {
1109 BaseCmp::visit (type);
1110 return;
1111 }
1112
1113 for (size_t j = 0; j < a->num_fields (); j++)
1114 {
1115 TyTy::StructFieldType *base_field = a->get_field_at_index (j);
1116 TyTy::StructFieldType *other_field = b->get_field_at_index (j);
1117
1118 TyTy::BaseType *this_field_ty = base_field->get_field_type ();
1119 TyTy::BaseType *other_field_ty = other_field->get_field_type ();
1120
1121 if (!this_field_ty->can_eq (other_field_ty, emit_error_flag))
1122 {
1123 BaseCmp::visit (type);
1124 return;
1125 }
1126 }
1127 }
1128
1129 ok = true;
1130 }
1131
1132 void visit (const InferType &type) override
1133 {
1134 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1135 {
1136 BaseCmp::visit (type);
1137 return;
1138 }
1139
1140 ok = true;
1141 }
1142
1143 private:
1144 const BaseType *get_base () const override { return base; }
1145 const ADTType *base;
1146 };
1147
1148 class TupleCmp : public BaseCmp
1149 {
1150 using Rust::TyTy::BaseCmp::visit;
1151
1152 public:
1153 TupleCmp (const TupleType *base, bool emit_errors)
1154 : BaseCmp (base, emit_errors), base (base)
1155 {}
1156
1157 void visit (const TupleType &type) override
1158 {
1159 if (base->num_fields () != type.num_fields ())
1160 {
1161 BaseCmp::visit (type);
1162 return;
1163 }
1164
1165 for (size_t i = 0; i < base->num_fields (); i++)
1166 {
1167 BaseType *bo = base->get_field (i);
1168 BaseType *fo = type.get_field (i);
1169
1170 if (!bo->can_eq (fo, emit_error_flag))
1171 {
1172 BaseCmp::visit (type);
1173 return;
1174 }
1175 }
1176
1177 ok = true;
1178 }
1179
1180 void visit (const InferType &type) override
1181 {
1182 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1183 {
1184 BaseCmp::visit (type);
1185 return;
1186 }
1187
1188 ok = true;
1189 }
1190
1191 private:
1192 const BaseType *get_base () const override { return base; }
1193 const TupleType *base;
1194 };
1195
1196 class USizeCmp : public BaseCmp
1197 {
1198 using Rust::TyTy::BaseCmp::visit;
1199
1200 public:
1201 USizeCmp (const USizeType *base, bool emit_errors)
1202 : BaseCmp (base, emit_errors), base (base)
1203 {}
1204
1205 void visit (const InferType &type) override
1206 {
1207 ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT;
1208 }
1209
1210 void visit (const USizeType &type) override { ok = true; }
1211
1212 private:
1213 const BaseType *get_base () const override { return base; }
1214 const USizeType *base;
1215 };
1216
1217 class ISizeCmp : public BaseCmp
1218 {
1219 using Rust::TyTy::BaseCmp::visit;
1220
1221 public:
1222 ISizeCmp (const ISizeType *base, bool emit_errors)
1223 : BaseCmp (base, emit_errors), base (base)
1224 {}
1225
1226 void visit (const InferType &type) override
1227 {
1228 ok = type.get_infer_kind () != InferType::InferTypeKind::FLOAT;
1229 }
1230
1231 void visit (const ISizeType &type) override { ok = true; }
1232
1233 private:
1234 const BaseType *get_base () const override { return base; }
1235 const ISizeType *base;
1236 };
1237
1238 class CharCmp : public BaseCmp
1239 {
1240 using Rust::TyTy::BaseCmp::visit;
1241
1242 public:
1243 CharCmp (const CharType *base, bool emit_errors)
1244 : BaseCmp (base, emit_errors), base (base)
1245 {}
1246
1247 void visit (const InferType &type) override
1248 {
1249 ok = type.get_infer_kind () == InferType::InferTypeKind::GENERAL;
1250 }
1251
1252 void visit (const CharType &type) override { ok = true; }
1253
1254 private:
1255 const BaseType *get_base () const override { return base; }
1256 const CharType *base;
1257 };
1258
1259 class ReferenceCmp : public BaseCmp
1260 {
1261 using Rust::TyTy::BaseCmp::visit;
1262
1263 public:
1264 ReferenceCmp (const ReferenceType *base, bool emit_errors)
1265 : BaseCmp (base, emit_errors), base (base)
1266 {}
1267
1268 void visit (const ReferenceType &type) override
1269 {
1270 auto base_type = base->get_base ();
1271 auto other_base_type = type.get_base ();
1272
1273 bool mutability_ok = base->is_mutable () ? type.is_mutable () : true;
1274 if (autoderef_cmp_flag)
1275 mutability_ok = base->mutability () == type.mutability ();
1276
1277 if (!mutability_ok)
1278 {
1279 BaseCmp::visit (type);
1280 return;
1281 }
1282
1283 if (!base_type->can_eq (other_base_type, emit_error_flag))
1284 {
1285 BaseCmp::visit (type);
1286 return;
1287 }
1288
1289 ok = true;
1290 }
1291
1292 void visit (const InferType &type) override
1293 {
1294 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1295 {
1296 BaseCmp::visit (type);
1297 return;
1298 }
1299
1300 ok = true;
1301 }
1302
1303 private:
1304 const BaseType *get_base () const override { return base; }
1305 const ReferenceType *base;
1306 };
1307
1308 class PointerCmp : public BaseCmp
1309 {
1310 using Rust::TyTy::BaseCmp::visit;
1311
1312 public:
1313 PointerCmp (const PointerType *base, bool emit_errors)
1314 : BaseCmp (base, emit_errors), base (base)
1315 {}
1316
1317 void visit (const PointerType &type) override
1318 {
1319 auto base_type = base->get_base ();
1320 auto other_base_type = type.get_base ();
1321
1322 bool mutability_ok = base->is_mutable () ? type.is_mutable () : true;
1323 if (autoderef_cmp_flag)
1324 mutability_ok = base->mutability () == type.mutability ();
1325
1326 if (!mutability_ok)
1327 {
1328 BaseCmp::visit (type);
1329 return;
1330 }
1331
1332 if (!base_type->can_eq (other_base_type, emit_error_flag))
1333 {
1334 BaseCmp::visit (type);
1335 return;
1336 }
1337
1338 ok = true;
1339 }
1340
1341 void visit (const InferType &type) override
1342 {
1343 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1344 {
1345 BaseCmp::visit (type);
1346 return;
1347 }
1348
1349 ok = true;
1350 }
1351
1352 private:
1353 const BaseType *get_base () const override { return base; }
1354 const PointerType *base;
1355 };
1356
1357 class ParamCmp : public BaseCmp
1358 {
1359 using Rust::TyTy::BaseCmp::visit;
1360
1361 public:
1362 ParamCmp (const ParamType *base, bool emit_errors)
1363 : BaseCmp (base, emit_errors), base (base)
1364 {}
1365
1366 // param types are a placeholder we shouldn't have cases where we unify
1367 // against it. eg: struct foo<T> { a: T }; When we invoke it we can do either:
1368 //
1369 // foo<i32>{ a: 123 }.
1370 // Then this enforces the i32 type to be referenced on the
1371 // field via an hirid.
1372 //
1373 // rust also allows for a = foo{a:123}; Where we can use an Inference Variable
1374 // to handle the typing of the struct
1375 bool can_eq (const BaseType *other) override
1376 {
1377 if (!base->can_resolve ())
1378 return BaseCmp::can_eq (other);
1379
1380 auto lookup = base->resolve ();
1381 return lookup->can_eq (other, emit_error_flag);
1382 }
1383
1384 // imagine the case where we have:
1385 // struct Foo<T>(T);
1386 // Then we declare a generic impl block
1387 // impl <X>Foo<X> { ... }
1388 // both of these types are compatible so we mostly care about the number of
1389 // generic arguments
1390 void visit (const ParamType &) override { ok = true; }
1391
1392 void visit (const TupleType &) override { ok = true; }
1393
1394 void visit (const InferType &) override { ok = true; }
1395
1396 void visit (const FnType &) override { ok = true; }
1397
1398 void visit (const FnPtr &) override { ok = true; }
1399
1400 void visit (const ADTType &) override { ok = true; }
1401
1402 void visit (const ArrayType &) override { ok = true; }
1403
1404 void visit (const SliceType &) override { ok = !autoderef_cmp_flag; }
1405
1406 void visit (const BoolType &) override { ok = true; }
1407
1408 void visit (const IntType &) override { ok = true; }
1409
1410 void visit (const UintType &) override { ok = true; }
1411
1412 void visit (const USizeType &) override { ok = true; }
1413
1414 void visit (const ISizeType &) override { ok = true; }
1415
1416 void visit (const FloatType &) override { ok = true; }
1417
1418 void visit (const CharType &) override { ok = true; }
1419
1420 void visit (const ReferenceType &) override { ok = true; }
1421
1422 void visit (const PointerType &) override { ok = true; }
1423
1424 void visit (const StrType &) override { ok = true; }
1425
1426 void visit (const NeverType &) override { ok = true; }
1427
1428 void visit (const DynamicObjectType &) override { ok = true; }
1429
1430 void visit (const PlaceholderType &type) override
1431 {
1432 ok = base->get_symbol ().compare (type.get_symbol ()) == 0;
1433 }
1434
1435 private:
1436 const BaseType *get_base () const override { return base; }
1437 const ParamType *base;
1438 };
1439
1440 class StrCmp : public BaseCmp
1441 {
1442 // FIXME we will need a enum for the StrType like ByteBuf etc..
1443 using Rust::TyTy::BaseCmp::visit;
1444
1445 public:
1446 StrCmp (const StrType *base, bool emit_errors)
1447 : BaseCmp (base, emit_errors), base (base)
1448 {}
1449
1450 void visit (const StrType &type) override { ok = true; }
1451
1452 void visit (const InferType &type) override
1453 {
1454 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1455 {
1456 BaseCmp::visit (type);
1457 return;
1458 }
1459
1460 ok = true;
1461 }
1462
1463 private:
1464 const BaseType *get_base () const override { return base; }
1465 const StrType *base;
1466 };
1467
1468 class NeverCmp : public BaseCmp
1469 {
1470 using Rust::TyTy::BaseCmp::visit;
1471
1472 public:
1473 NeverCmp (const NeverType *base, bool emit_errors)
1474 : BaseCmp (base, emit_errors), base (base)
1475 {}
1476
1477 void visit (const NeverType &type) override { ok = true; }
1478
1479 void visit (const InferType &type) override
1480 {
1481 if (type.get_infer_kind () != InferType::InferTypeKind::GENERAL)
1482 {
1483 BaseCmp::visit (type);
1484 return;
1485 }
1486
1487 ok = true;
1488 }
1489
1490 private:
1491 const BaseType *get_base () const override { return base; }
1492 const NeverType *base;
1493 };
1494
1495 class PlaceholderCmp : public BaseCmp
1496 {
1497 using Rust::TyTy::BaseCmp::visit;
1498
1499 public:
1500 PlaceholderCmp (const PlaceholderType *base, bool emit_errors)
1501 : BaseCmp (base, emit_errors), base (base)
1502 {}
1503
1504 bool can_eq (const BaseType *other) override
1505 {
1506 if (!base->can_resolve ())
1507 return BaseCmp::can_eq (other);
1508
1509 BaseType *lookup = base->resolve ();
1510 return lookup->can_eq (other, emit_error_flag);
1511 }
1512
1513 void visit (const TupleType &) override { ok = true; }
1514
1515 void visit (const ADTType &) override { ok = true; }
1516
1517 void visit (const InferType &) override { ok = true; }
1518
1519 void visit (const FnType &) override { ok = true; }
1520
1521 void visit (const FnPtr &) override { ok = true; }
1522
1523 void visit (const ArrayType &) override { ok = true; }
1524
1525 void visit (const BoolType &) override { ok = true; }
1526
1527 void visit (const IntType &) override { ok = true; }
1528
1529 void visit (const UintType &) override { ok = true; }
1530
1531 void visit (const USizeType &) override { ok = true; }
1532
1533 void visit (const ISizeType &) override { ok = true; }
1534
1535 void visit (const FloatType &) override { ok = true; }
1536
1537 void visit (const ErrorType &) override { ok = true; }
1538
1539 void visit (const CharType &) override { ok = true; }
1540
1541 void visit (const ReferenceType &) override { ok = true; }
1542
1543 void visit (const ParamType &) override { ok = true; }
1544
1545 void visit (const StrType &) override { ok = true; }
1546
1547 void visit (const NeverType &) override { ok = true; }
1548
1549 void visit (const SliceType &) override { ok = true; }
1550
1551 private:
1552 const BaseType *get_base () const override { return base; }
1553
1554 const PlaceholderType *base;
1555 };
1556
1557 class DynamicCmp : public BaseCmp
1558 {
1559 using Rust::TyTy::BaseCmp::visit;
1560
1561 public:
1562 DynamicCmp (const DynamicObjectType *base, bool emit_errors)
1563 : BaseCmp (base, emit_errors), base (base)
1564 {}
1565
1566 void visit (const DynamicObjectType &type) override
1567 {
1568 if (base->num_specified_bounds () != type.num_specified_bounds ())
1569 {
1570 BaseCmp::visit (type);
1571 return;
1572 }
1573
1574 Location ref_locus = mappings->lookup_location (type.get_ref ());
1575 ok = base->bounds_compatible (type, ref_locus, false);
1576 }
1577
1578 private:
1579 const BaseType *get_base () const override { return base; }
1580
1581 const DynamicObjectType *base;
1582 };
1583
1584 } // namespace TyTy
1585 } // namespace Rust
1586
1587 #endif // RUST_TYTY_CMP_H