]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support `TypeDecorator` subclasses in `Column()` declarations
authorBryan Forbes <bryan@reigndropsfall.net>
Wed, 7 Apr 2021 23:56:27 +0000 (18:56 -0500)
committerBryan Forbes <bryan@reigndropsfall.net>
Wed, 7 Apr 2021 23:56:27 +0000 (18:56 -0500)
lib/sqlalchemy/ext/mypy/infer.py
test/ext/mypy/files/type_decorator.py [new file with mode: 0644]

index 49dd9fb7435fba5db26ad03785106583a4b9312c..f0f6be36f35d7278bca5748ff13b02f0b0249b75 100644 (file)
@@ -10,6 +10,7 @@ from typing import Union
 
 from mypy import nodes
 from mypy import types
+from mypy.maptype import map_instance_to_supertype
 from mypy.messages import format_type
 from mypy.nodes import AssignmentStmt
 from mypy.nodes import CallExpr
@@ -413,9 +414,11 @@ def _extract_python_type_from_typeengine(
             n = api.lookup_fully_qualified("builtins.str")
             return Instance(n.node, [])
 
-    for mr in node.mro:
-        if mr.bases:
-            for base_ in mr.bases:
-                if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine":
-                    return base_.args[-1]
-    assert False, "could not extract Python type from node: %s" % node
+    assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
+        "could not extract Python type from node: %s" % node
+    )
+    type_engine = map_instance_to_supertype(
+        Instance(node, []),
+        api.modules["sqlalchemy.sql.type_api"].names["TypeEngine"].node,
+    )
+    return type_engine.args[-1]
diff --git a/test/ext/mypy/files/type_decorator.py b/test/ext/mypy/files/type_decorator.py
new file mode 100644 (file)
index 0000000..f4ab4fd
--- /dev/null
@@ -0,0 +1,45 @@
+from typing import Any
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy import TypeDecorator
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class IntToStr(TypeDecorator[int]):
+    impl = String
+
+    def process_bind_param(
+        self,
+        value: Any,
+        dialect: Any,
+    ) -> Optional[str]:
+        return str(value) if value is not None else value
+
+    def process_result_value(
+        self,
+        value: Any,
+        dialect: Any,
+    ) -> Optional[int]:
+        return int(value) if value is not None else value
+
+    def copy(self, /, **kwargs: Any) -> "IntToStr":
+        return IntToStr(self.impl.length)
+
+
+class Thing(Base):
+    __tablename__ = "things"
+
+    id: int = Column(Integer, primary_key=True)
+    intToStr: int = Column(IntToStr)
+
+
+t1 = Thing(intToStr=5)
+
+i5: int = t1.intToStr
+
+t1.intToStr = 8