]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
include cls locals in annotation evaluate
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Jan 2024 17:47:02 +0000 (12:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Feb 2024 00:38:04 +0000 (19:38 -0500)
Fixed issue where it was not possible to use a type (such as an enum)
within a :class:`_orm.Mapped` container type if that type were declared
locally within the class body.  The scope of locals used for the eval now
includes that of the class body itself.  In addition, the expression within
:class:`_orm.Mapped` may also refer to the class name itself, if used as a
string or with future annotations mode.

Fixes: #10899
Change-Id: Id4d07499558e457e63b483ff44c0972d9265409d

doc/build/changelog/unreleased_20/10899.rst [new file with mode: 0644]
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/10899.rst b/doc/build/changelog/unreleased_20/10899.rst
new file mode 100644 (file)
index 0000000..6923813
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10899
+
+    Fixed issue where it was not possible to use a type (such as an enum)
+    within a :class:`_orm.Mapped` container type if that type were declared
+    locally within the class body.  The scope of locals used for the eval now
+    includes that of the class body itself.  In addition, the expression within
+    :class:`_orm.Mapped` may also refer to the class name itself, if used as a
+    string or with future annotations mode.
index ce3aa9fe321b2260c492012d2ed6a59a70fba6e1..1940beac57786b664eed3d65236a420151cf045d 100644 (file)
@@ -153,7 +153,7 @@ def de_stringify_annotation(
             annotation = str_cleanup_fn(annotation, originating_module)
 
         annotation = eval_expression(
-            annotation, originating_module, locals_=locals_
+            annotation, originating_module, locals_=locals_, in_class=cls
         )
 
     if (
@@ -206,6 +206,7 @@ def eval_expression(
     module_name: str,
     *,
     locals_: Optional[Mapping[str, Any]] = None,
+    in_class: Optional[Type[Any]] = None,
 ) -> Any:
     try:
         base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
@@ -216,7 +217,18 @@ def eval_expression(
         ) from ke
 
     try:
-        annotation = eval(expression, base_globals, locals_)
+        if in_class is not None:
+            cls_namespace = dict(in_class.__dict__)
+            cls_namespace.setdefault(in_class.__name__, in_class)
+
+            # see #10899.  We want the locals/globals to take precedence
+            # over the class namespace in this context, even though this
+            # is not the usual way variables would resolve.
+            cls_namespace.update(base_globals)
+
+            annotation = eval(expression, cls_namespace, locals_)
+        else:
+            annotation = eval(expression, base_globals, locals_)
     except Exception as err:
         raise NameError(
             f"Could not de-stringify annotation {expression!r}"
index 833518a42756d77d0820e0cc952d54e9c1ec7562..e3b5df0ad48ca1eaa3042853905a1d30b166413b 100644 (file)
@@ -8,6 +8,7 @@ the ``test_tm_future_annotations_sync`` by the ``sync_test_file`` script.
 
 from __future__ import annotations
 
+import enum
 from typing import ClassVar
 from typing import Dict
 from typing import List
@@ -29,8 +30,11 @@ from sqlalchemy.orm import KeyFuncDict
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.sql import sqltypes
+from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
 from .test_typed_mapping import expect_annotation_syntax_error
 from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
 from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
@@ -112,6 +116,85 @@ class MappedColumnTest(_MappedColumnTest):
             select(Foo), "SELECT foo.id, foo.data, foo.data2 FROM foo"
         )
 
+    def test_type_favors_outer(self, decl_base):
+        """test #10899, that we maintain favoring outer names vs. inner.
+        this is for backwards compatibility as well as what people
+        usually expect regarding the names of attributes in the class.
+
+        """
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            uuid: Mapped[uuid.UUID] = mapped_column()
+
+        is_true(isinstance(User.__table__.c.uuid.type, sqltypes.Uuid))
+
+    def test_type_inline_cls_qualified(self, decl_base):
+        """test #10899, where we test that we can refer to the class name
+        directly to refer to class-bound elements.
+
+        """
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[User.Role]
+
+        is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+
+    def test_type_inline_disambiguate(self, decl_base):
+        """test #10899, where we test that we can refer to an inner name
+        that's not in conflict directly without qualification.
+
+        """
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[Role]
+
+        is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+        eq_(User.__table__.c.role.type.name, "role")  # and not 'enum'
+
+    def test_type_inner_can_be_qualified(self, decl_base):
+        """test #10899, same test as that of Role, using it to qualify against
+        a global variable with the same name.
+
+        """
+
+        global SomeGlobalName
+        SomeGlobalName = None
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class SomeGlobalName(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[User.SomeGlobalName]
+
+        is_true(isinstance(User.__table__.c.role.type, sqltypes.Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.SomeGlobalName)
+
     def test_indirect_mapped_name_local_level(self, decl_base):
         """test #8759.
 
index d2f2a0261f367bbf3465c01b2b43ec3c5e5b18e7..72c54cbca21d9dfc7760b94c9c22f845a2d80165 100644 (file)
@@ -192,6 +192,46 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             eq_(Foo.__table__.c.data.default.arg, 5)
 
+    def test_type_inline_declaration(self, decl_base):
+        """test #10899"""
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[Role]
+
+        is_true(isinstance(User.__table__.c.role.type, Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+        eq_(User.__table__.c.role.type.name, "role")  # and not 'enum'
+
+    def test_type_uses_inner_when_present(self, decl_base):
+        """test #10899, that we use inner name when appropriate"""
+
+        class Role(enum.Enum):
+            foo = "foo"
+            bar = "bar"
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[Role]
+
+        is_true(isinstance(User.__table__.c.role.type, Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+        eq_(User.__table__.c.role.type.name, "role")  # and not 'enum'
+
     def test_legacy_declarative_base(self):
         typ = VARCHAR(50)
         Base = declarative_base(type_annotation_map={str: typ})
index 37aa216d543577692644de941ea9e7253e657531..ed36ea2dce68076bd61222b0b48b995260c73c90 100644 (file)
@@ -183,6 +183,46 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             eq_(Foo.__table__.c.data.default.arg, 5)
 
+    def test_type_inline_declaration(self, decl_base):
+        """test #10899"""
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[Role]
+
+        is_true(isinstance(User.__table__.c.role.type, Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+        eq_(User.__table__.c.role.type.name, "role")  # and not 'enum'
+
+    def test_type_uses_inner_when_present(self, decl_base):
+        """test #10899, that we use inner name when appropriate"""
+
+        class Role(enum.Enum):
+            foo = "foo"
+            bar = "bar"
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            class Role(enum.Enum):
+                admin = "admin"
+                user = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            role: Mapped[Role]
+
+        is_true(isinstance(User.__table__.c.role.type, Enum))
+        eq_(User.__table__.c.role.type.length, 5)
+        is_(User.__table__.c.role.type.enum_class, User.Role)
+        eq_(User.__table__.c.role.type.name, "role")  # and not 'enum'
+
     def test_legacy_declarative_base(self):
         typ = VARCHAR(50)
         Base = declarative_base(type_annotation_map={str: typ})