]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure ClassVar succeeds in cleanup_mapped_str_annotation
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Oct 2023 19:27:40 +0000 (15:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Oct 2023 19:28:42 +0000 (15:28 -0400)
Fixed bug in ORM annotated declarative where using a ``ClassVar`` that
nonetheless referred in some way to an ORM mapped class name would fail to
be interpreted as a ``ClassVar`` that's not mapped.

Fixes: #10472
Change-Id: I6606b0f0222ef088e594eb3b0c0653d983d6ff89

doc/build/changelog/unreleased_20/10472.rst [new file with mode: 0644]
lib/sqlalchemy/orm/util.py
test/orm/declarative/test_tm_future_annotations.py

diff --git a/doc/build/changelog/unreleased_20/10472.rst b/doc/build/changelog/unreleased_20/10472.rst
new file mode 100644 (file)
index 0000000..be84b2b
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10472
+
+    Fixed bug in ORM annotated declarative where using a ``ClassVar`` that
+    nonetheless referred in some way to an ORM mapped class name would fail to
+    be interpreted as a ``ClassVar`` that's not mapped.
index feb82a648ea8d4aff0ff6d7da78c1801cf50c4a4..ea2f1a12e9365d562aace63b6c2a4e18476c5813 100644 (file)
@@ -2250,14 +2250,17 @@ def _cleanup_mapped_str_annotation(
             "outside of TYPE_CHECKING blocks"
         ) from ne
 
-    try:
-        if issubclass(obj, _MappedAnnotationBase):
-            real_symbol = obj.__name__
-        else:
+    if obj is typing.ClassVar:
+        real_symbol = "ClassVar"
+    else:
+        try:
+            if issubclass(obj, _MappedAnnotationBase):
+                real_symbol = obj.__name__
+            else:
+                return annotation
+        except TypeError:
+            # avoid isinstance(obj, type) check, just catch TypeError
             return annotation
-    except TypeError:
-        # avoid isinstance(obj, type) check, just catch TypeError
-        return annotation
 
     # note: if one of the codepaths above didn't define real_symbol and
     # then didn't return, real_symbol raises UnboundLocalError
index 1677cdbb9bde2cd1bd8617df111a79c3c5a8a101..833518a42756d77d0820e0cc952d54e9c1ec7562 100644 (file)
@@ -8,15 +8,20 @@ the ``test_tm_future_annotations_sync`` by the ``sync_test_file`` script.
 
 from __future__ import annotations
 
+from typing import ClassVar
+from typing import Dict
 from typing import List
+from typing import Optional
 from typing import TYPE_CHECKING
 from typing import TypeVar
 import uuid
 
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import select
+from sqlalchemy import testing
 from sqlalchemy import Uuid
 import sqlalchemy.orm
 from sqlalchemy.orm import attribute_keyed_dict
@@ -181,6 +186,41 @@ class MappedColumnTest(_MappedColumnTest):
                 id: Mapped[int] = mapped_column(primary_key=True)
                 data: Mapped[fake]  # noqa
 
+    @testing.variation(
+        "reference_type",
+        [
+            "plain",
+            "plain_optional",
+            "container_w_local_mapped",
+            "container_w_remote_mapped",
+        ],
+    )
+    def test_i_have_a_classvar_on_my_class(self, decl_base, reference_type):
+        if reference_type.container_w_remote_mapped:
+
+            class MyOtherClass(decl_base):
+                __tablename__ = "myothertable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            if reference_type.container_w_remote_mapped:
+                status: ClassVar[Dict[str, MyOtherClass]]
+            elif reference_type.container_w_local_mapped:
+                status: ClassVar[Dict[str, MyClass]]
+            elif reference_type.plain_optional:
+                status: ClassVar[Optional[int]]
+            elif reference_type.plain:
+                status: ClassVar[int]
+
+        m1 = MyClass(id=1, data=5)
+        assert "status" not in inspect(m1).mapper.attrs
+
 
 class MappedOneArg(KeyFuncDict[str, _R]):
     pass