]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
call toinstance() on type arguments passed to mapped_column()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Jul 2022 15:49:56 +0000 (11:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Jul 2022 17:26:36 +0000 (13:26 -0400)
Change-Id: I875cfbd925cb08e0a5235f87d13341d319c955bc

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/orm/declarative/test_typed_mapping.py

index 7308b8fb1250c58f1e8115e797ca7135be986f42..c5f50d7b450edd1a5eb25d049d72ac6003c1838e 100644 (file)
@@ -684,4 +684,5 @@ class MappedColumn(
                     f"Could not locate SQLAlchemy Core "
                     f"type for Python type: {our_type}"
                 )
-            self.column.type = new_sqltype  # type: ignore
+
+            self.column.type = sqltypes.to_instance(new_sqltype)
index c67614070f88f75f117239f5b1e104b24463ccb9..de833cd89369b690c432597690a018bd88173d8e 100644 (file)
@@ -46,7 +46,7 @@ from .elements import Slice
 from .elements import TypeCoerce as type_coerce  # noqa
 from .type_api import Emulated
 from .type_api import NativeForEmulated  # noqa
-from .type_api import to_instance
+from .type_api import to_instance as to_instance
 from .type_api import TypeDecorator as TypeDecorator
 from .type_api import TypeEngine as TypeEngine
 from .type_api import TypeEngineMixin
index 46bf151eaf8f362aa3b4320d828372496f711b9c..efaf5d2a79201a04b39358be3a0263c996a44245 100644 (file)
@@ -41,6 +41,7 @@ from ..util.typing import TypeGuard
 
 # these are back-assigned by sqltypes.
 if typing.TYPE_CHECKING:
+    from ._typing import _TypeEngineArgument
     from .elements import BindParameter
     from .elements import ColumnElement
     from .operators import OperatorType
@@ -55,7 +56,6 @@ if typing.TYPE_CHECKING:
     from .sqltypes import TABLEVALUE as TABLEVALUE  # noqa: F401
     from ..engine.interfaces import Dialect
 
-
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -642,7 +642,9 @@ class TypeEngine(Visitable, Generic[_T]):
         raise NotImplementedError()
 
     def with_variant(
-        self: SelfTypeEngine, type_: TypeEngine[Any], *dialect_names: str
+        self: SelfTypeEngine,
+        type_: _TypeEngineArgument[Any],
+        *dialect_names: str,
     ) -> SelfTypeEngine:
         r"""Produce a copy of this type object that will utilize the given
         type when applied to the dialect of the given name.
index 6f60a652ff1bc2be72af023c49250f47e6063510..beb5d783bf4ede070a7bc5fb9fb09a08ea540f1b 100644 (file)
@@ -39,6 +39,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.orm.collections import MappedCollection
+from sqlalchemy.schema import CreateTable
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
@@ -100,6 +101,32 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(MyClass.__table__.c.data.type, typ)
         is_true(MyClass.__table__.c.id.primary_key)
 
+    @testing.combinations(
+        (BIGINT(),),
+        (BIGINT,),
+        (Integer().with_variant(BIGINT, "default")),
+        (Integer().with_variant(BIGINT(), "default")),
+        (BIGINT().with_variant(String(), "some_other_dialect")),
+    )
+    def test_type_map_varieties(self, typ):
+
+        Base = declarative_base(type_annotation_map={int: typ})
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            x: Mapped[int]
+            y: Mapped[int] = mapped_column()
+            z: Mapped[int] = mapped_column(typ)
+
+        self.assert_compile(
+            CreateTable(MyClass.__table__),
+            "CREATE TABLE mytable (id BIGINT NOT NULL, "
+            "x BIGINT NOT NULL, y BIGINT NOT NULL, z BIGINT NOT NULL, "
+            "PRIMARY KEY (id))",
+        )
+
     def test_required_no_arg(self, decl_base):
         with expect_raises_message(
             sa_exc.ArgumentError,