]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for Boolean, Enum
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Mar 2021 21:10:18 +0000 (17:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Mar 2021 21:23:22 +0000 (17:23 -0400)
Fixed bug in Mypy plugin where the Python type detection
for the :class:`_sqltypes.Boolean` column type would produce
an exception; additionally implemented support for :class:`_sqltypes.Enum`,
including detection of a string-based enum vs. use of Python ``enum.Enum``.

Fixes: #6109
Change-Id: I25e546ea2f50d90be2d6fec303976d82849a3d31

lib/sqlalchemy/ext/mypy/decl_class.py
lib/sqlalchemy/sql/sqltypes.py
test/ext/mypy/files/boolean_col.py [new file with mode: 0644]
test/ext/mypy/files/enum_col.py [new file with mode: 0644]

index 7a0c251c3b3ace0453fd5f2127ba20573d5227be..9fb1fa807a869fd3791d11cbe8c17b7bec28869e 100644 (file)
@@ -201,7 +201,9 @@ def _scan_declarative_decorator_stmt(
 
                     left_hand_explicit_type = UnionType(
                         [
-                            _extract_python_type_from_typeengine(sym.node),
+                            _extract_python_type_from_typeengine(
+                                api, sym.node, []
+                            ),
                             NoneType(),
                         ]
                     )
@@ -692,11 +694,13 @@ def _infer_type_from_decl_column(
         if isinstance(column_arg, nodes.CallExpr):
             # x = Column(String(50))
             callee = column_arg.callee
+            type_args = column_arg.args
             break
         elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
             if isinstance(column_arg.node, TypeInfo):
                 # x = Column(String)
                 callee = column_arg
+                type_args = ()
                 break
             else:
                 # x = Column(some_name, String), go to next argument
@@ -710,9 +714,11 @@ def _infer_type_from_decl_column(
     if callee is None:
         return None
 
-    if names._mro_has_id(callee.node.mro, names.TYPEENGINE):
+    if isinstance(callee.node, TypeInfo) and names._mro_has_id(
+        callee.node.mro, names.TYPEENGINE
+    ):
         python_type_for_type = _extract_python_type_from_typeengine(
-            callee.node
+            api, callee.node, type_args
         )
 
         if left_hand_explicit_type is not None:
@@ -977,13 +983,25 @@ def _apply_placeholder_attr_to_class(
     cls.info.names[attrname] = SymbolTableNode(MDEF, var)
 
 
-def _extract_python_type_from_typeengine(node: TypeInfo) -> Instance:
-    for mr in node.mro:
-        if (
-            mr.bases
-            and mr.bases[-1].type.fullname
-            == "sqlalchemy.sql.type_api.TypeEngine"
+def _extract_python_type_from_typeengine(
+    api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
+) -> Instance:
+    if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
+        first_arg = type_args[0]
+        if isinstance(first_arg, NameExpr) and isinstance(
+            first_arg.node, TypeInfo
         ):
-            return mr.bases[-1].args[-1]
-    else:
-        assert False, "could not extract Python type from node: %s" % node
+            for base_ in first_arg.node.mro:
+                if base_.fullname == "enum.Enum":
+                    return Instance(first_arg.node, [])
+            # TODO: support other pep-435 types here
+        else:
+            n = api.lookup_fully_qualified("builtins.str")
+            return Instance(n.node, [])
+
+    for mr in node.mro:
+        if mr.bases:
+            for base_ in mr.bases:
+                if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine":
+                    return base_.args[-1]
+    assert False, "could not extract Python type from node: %s" % node
index 816423d1b6456e89abc7c7bba7356431fab926b3..a73c611476c45ab991560bd13f2fcc41cf76a211 100644 (file)
@@ -1324,15 +1324,15 @@ class Enum(Emulated, String, SchemaType):
         by that backend.
 
         :param \*enums: either exactly one PEP-435 compliant enumerated type
-           or one or more string or unicode enumeration labels. If unicode
-           labels are present, the `convert_unicode` flag is auto-enabled.
+           or one or more string labels.
 
            .. versionadded:: 1.1 a PEP-435 style enumerated class may be
               passed.
 
         :param convert_unicode: Enable unicode-aware bind parameter and
-           result-set processing for this Enum's data. This is set
-           automatically based on the presence of unicode label strings.
+           result-set processing for this Enum's data under Python 2 only.
+           Under Python 2, this is set automatically based on the presence of
+           unicode label strings.  This flag will be removed in SQLAlchemy 2.0.
 
         :param create_constraint: defaults to False.  When creating a
            non-native enumerated type, also build a CHECK constraint on the
diff --git a/test/ext/mypy/files/boolean_col.py b/test/ext/mypy/files/boolean_col.py
new file mode 100644 (file)
index 0000000..3e361ad
--- /dev/null
@@ -0,0 +1,24 @@
+from typing import Optional
+
+from sqlalchemy import Boolean
+from sqlalchemy import Column
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
+
+
+class TestBoolean(Base):
+    __tablename__ = "test_boolean"
+
+    flag = Column(Boolean)
+
+    bflag: bool = Column(Boolean(create_constraint=True))
+
+
+expr = TestBoolean.flag.is_(True)
+
+t1 = TestBoolean(flag=True)
+
+x: Optional[bool] = t1.flag
+
+y: bool = t1.bflag
diff --git a/test/ext/mypy/files/enum_col.py b/test/ext/mypy/files/enum_col.py
new file mode 100644 (file)
index 0000000..cfea388
--- /dev/null
@@ -0,0 +1,40 @@
+import enum
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import Enum
+from sqlalchemy.orm import declarative_base
+
+
+class MyEnum(enum.Enum):
+    one = 1
+    two = 2
+    three = 3
+
+
+Base = declarative_base()
+
+one, two, three = "one", "two", "three"
+
+
+class TestEnum(Base):
+    __tablename__ = "test_enum"
+
+    e1: str = Column(Enum("one", "two", "three"))
+
+    e2: MyEnum = Column(Enum(MyEnum))
+
+    e3 = Column(Enum(one, two, three))
+
+    e4 = Column(Enum(MyEnum))
+
+
+t1 = TestEnum(e1="two", e2=MyEnum.three, e3="one", e4=MyEnum.one)
+
+x: str = t1.e1
+
+y: MyEnum = t1.e2
+
+z: Optional[str] = t1.e3
+
+z2: Optional[MyEnum] = t1.e4