]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix(6435): support `MemberExpr` for enum column declaration
authorHiroshi Ogawa <hi.ogawa.zz@gmail.com>
Fri, 1 Oct 2021 22:59:22 +0000 (18:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Oct 2021 16:49:04 +0000 (12:49 -0400)
Fixed issue in mypy plugin to improve upon some issues detecting ``Enum()``
SQL types containing custom Python enumeration classes. Pull request
courtesy Hiroshi Ogawa.

Fixes: #6435
Closes: #7048
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7048
Pull-request-sha: 59f5c89688792f6af3b07488d5cf97f8f2e964dc

Change-Id: I05adbec74ceac1ecfdc5a242bfe7aa4b2eb805e4

doc/build/changelog/unreleased_14/6435.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/infer.py
test/ext/mypy/incremental/ticket_6435/__init__.py [new file with mode: 0644]
test/ext/mypy/incremental/ticket_6435/enum_col_import1.py [new file with mode: 0644]
test/ext/mypy/incremental/ticket_6435/enum_col_import2.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_14/6435.rst b/doc/build/changelog/unreleased_14/6435.rst
new file mode 100644 (file)
index 0000000..d07754d
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, mypy
+    :tickets: 6435
+
+    Fixed issue in mypy plugin to improve upon some issues detecting ``Enum()``
+    SQL types containing custom Python enumeration classes. Pull request
+    courtesy Hiroshi Ogawa.
index 52570f772bd996c02a6e43a1ac1d8e7c3b71ac40..6d243b6ec1d9445df335915610aa880cfea825b6 100644 (file)
@@ -521,7 +521,7 @@ def extract_python_type_from_typeengine(
 ) -> ProperType:
     if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
         first_arg = type_args[0]
-        if isinstance(first_arg, NameExpr) and isinstance(
+        if isinstance(first_arg, RefExpr) and isinstance(
             first_arg.node, TypeInfo
         ):
             for base_ in first_arg.node.mro:
diff --git a/test/ext/mypy/incremental/ticket_6435/__init__.py b/test/ext/mypy/incremental/ticket_6435/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/test/ext/mypy/incremental/ticket_6435/enum_col_import1.py b/test/ext/mypy/incremental/ticket_6435/enum_col_import1.py
new file mode 100644 (file)
index 0000000..fbdbb4f
--- /dev/null
@@ -0,0 +1,11 @@
+import enum
+
+
+class StrEnum(enum.Enum):
+    one = "one"
+    two = "two"
+
+
+class IntEnum(enum.Enum):
+    one = 1
+    two = 2
diff --git a/test/ext/mypy/incremental/ticket_6435/enum_col_import2.py b/test/ext/mypy/incremental/ticket_6435/enum_col_import2.py
new file mode 100644 (file)
index 0000000..4f29932
--- /dev/null
@@ -0,0 +1,27 @@
+from sqlalchemy import Column
+from sqlalchemy import Enum
+from sqlalchemy.orm import declarative_base, Mapped
+from . import enum_col_import1
+from .enum_col_import1 import IntEnum, StrEnum
+
+Base = declarative_base()
+
+
+class TestEnum(Base):
+    __tablename__ = "test_enum"
+
+    e1: Mapped[StrEnum] = Column(Enum(StrEnum))
+    e2: StrEnum = Column(Enum(StrEnum))
+
+    e3: Mapped[IntEnum] = Column(Enum(IntEnum))
+    e4: IntEnum = Column(Enum(IntEnum))
+
+    e5: Mapped[enum_col_import1.StrEnum] = Column(
+        Enum(enum_col_import1.StrEnum)
+    )
+    e6: enum_col_import1.StrEnum = Column(Enum(enum_col_import1.StrEnum))
+
+    e7: Mapped[enum_col_import1.IntEnum] = Column(
+        Enum(enum_col_import1.IntEnum)
+    )
+    e8: enum_col_import1.IntEnum = Column(Enum(enum_col_import1.IntEnum))