From: Bryan Forbes Date: Wed, 7 Apr 2021 23:56:27 +0000 (-0500) Subject: Support `TypeDecorator` subclasses in `Column()` declarations X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=acc42beaf630db2d5b4166098ad062e1df21ef51;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support `TypeDecorator` subclasses in `Column()` declarations --- diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 49dd9fb743..f0f6be36f3 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -10,6 +10,7 @@ from typing import Union from mypy import nodes from mypy import types +from mypy.maptype import map_instance_to_supertype from mypy.messages import format_type from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr @@ -413,9 +414,11 @@ def _extract_python_type_from_typeengine( n = api.lookup_fully_qualified("builtins.str") return Instance(n.node, []) - for mr in node.mro: - if mr.bases: - for base_ in mr.bases: - if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine": - return base_.args[-1] - assert False, "could not extract Python type from node: %s" % node + assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), ( + "could not extract Python type from node: %s" % node + ) + type_engine = map_instance_to_supertype( + Instance(node, []), + api.modules["sqlalchemy.sql.type_api"].names["TypeEngine"].node, + ) + return type_engine.args[-1] diff --git a/test/ext/mypy/files/type_decorator.py b/test/ext/mypy/files/type_decorator.py new file mode 100644 index 0000000000..f4ab4fd60c --- /dev/null +++ b/test/ext/mypy/files/type_decorator.py @@ -0,0 +1,45 @@ +from typing import Any +from typing import Optional + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import TypeDecorator +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class IntToStr(TypeDecorator[int]): + impl = String + + def process_bind_param( + self, + value: Any, + dialect: Any, + ) -> Optional[str]: + return str(value) if value is not None else value + + def process_result_value( + self, + value: Any, + dialect: Any, + ) -> Optional[int]: + return int(value) if value is not None else value + + def copy(self, /, **kwargs: Any) -> "IntToStr": + return IntToStr(self.impl.length) + + +class Thing(Base): + __tablename__ = "things" + + id: int = Column(Integer, primary_key=True) + intToStr: int = Column(IntToStr) + + +t1 = Thing(intToStr=5) + +i5: int = t1.intToStr + +t1.intToStr = 8