From 33a3dcd112b77600093890582f50ee7bb9b7c9b7 Mon Sep 17 00:00:00 2001 From: Hiroshi Ogawa Date: Sun, 19 Sep 2021 15:59:20 +0900 Subject: [PATCH] fix: support `MemberExpr` for enum column declaration Fixes: #6435 --- lib/sqlalchemy/ext/mypy/infer.py | 2 +- test/ext/mypy/files/__init__.py | 0 test/ext/mypy/files/enum_col_import1.py | 9 +++++++++ test/ext/mypy/files/enum_col_import2.py | 24 ++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 test/ext/mypy/files/__init__.py create mode 100644 test/ext/mypy/files/enum_col_import1.py create mode 100644 test/ext/mypy/files/enum_col_import2.py diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 52570f772b..6d243b6ec1 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -521,7 +521,7 @@ def extract_python_type_from_typeengine( ) -> ProperType: if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: first_arg = type_args[0] - if isinstance(first_arg, NameExpr) and isinstance( + if isinstance(first_arg, RefExpr) and isinstance( first_arg.node, TypeInfo ): for base_ in first_arg.node.mro: diff --git a/test/ext/mypy/files/__init__.py b/test/ext/mypy/files/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/ext/mypy/files/enum_col_import1.py b/test/ext/mypy/files/enum_col_import1.py new file mode 100644 index 0000000000..4efae0e56d --- /dev/null +++ b/test/ext/mypy/files/enum_col_import1.py @@ -0,0 +1,9 @@ +import enum + +class StrEnum(enum.Enum): + one = "one" + two = "two" + +class IntEnum(enum.Enum): + one = 1 + two = 2 diff --git a/test/ext/mypy/files/enum_col_import2.py b/test/ext/mypy/files/enum_col_import2.py new file mode 100644 index 0000000000..7607d28c13 --- /dev/null +++ b/test/ext/mypy/files/enum_col_import2.py @@ -0,0 +1,24 @@ +from sqlalchemy import Column +from sqlalchemy import Enum +from sqlalchemy.orm import declarative_base, Mapped + +from .enum_col_import1 import StrEnum, IntEnum + +from . import enum_col_import1 + +Base = declarative_base() + +class TestEnum(Base): + __tablename__ = "test_enum" + + e1: Mapped[StrEnum] = Column(Enum(StrEnum)) + e2: StrEnum = Column(Enum(StrEnum)) + + e3: Mapped[IntEnum] = Column(Enum(IntEnum)) + e4: IntEnum = Column(Enum(IntEnum)) + + e5: Mapped[enum_col_import1.StrEnum] = Column(Enum(enum_col_import1.StrEnum)) + e6: enum_col_import1.StrEnum = Column(Enum(enum_col_import1.StrEnum)) + + e7: Mapped[enum_col_import1.IntEnum] = Column(Enum(enum_col_import1.IntEnum)) + e8: enum_col_import1.IntEnum = Column(Enum(enum_col_import1.IntEnum)) -- 2.47.3