From f8d15739d36da138d34fb3fecac1fa043e65e48d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 2 Jul 2022 11:49:56 -0400 Subject: [PATCH] call toinstance() on type arguments passed to mapped_column() Change-Id: I875cfbd925cb08e0a5235f87d13341d319c955bc --- lib/sqlalchemy/orm/properties.py | 3 ++- lib/sqlalchemy/sql/sqltypes.py | 2 +- lib/sqlalchemy/sql/type_api.py | 6 +++-- test/orm/declarative/test_typed_mapping.py | 27 ++++++++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 7308b8fb12..c5f50d7b45 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -684,4 +684,5 @@ class MappedColumn( f"Could not locate SQLAlchemy Core " f"type for Python type: {our_type}" ) - self.column.type = new_sqltype # type: ignore + + self.column.type = sqltypes.to_instance(new_sqltype) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index c67614070f..de833cd893 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -46,7 +46,7 @@ from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa from .type_api import Emulated from .type_api import NativeForEmulated # noqa -from .type_api import to_instance +from .type_api import to_instance as to_instance from .type_api import TypeDecorator as TypeDecorator from .type_api import TypeEngine as TypeEngine from .type_api import TypeEngineMixin diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 46bf151eaf..efaf5d2a79 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -41,6 +41,7 @@ from ..util.typing import TypeGuard # these are back-assigned by sqltypes. if typing.TYPE_CHECKING: + from ._typing import _TypeEngineArgument from .elements import BindParameter from .elements import ColumnElement from .operators import OperatorType @@ -55,7 +56,6 @@ if typing.TYPE_CHECKING: from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401 from ..engine.interfaces import Dialect - _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -642,7 +642,9 @@ class TypeEngine(Visitable, Generic[_T]): raise NotImplementedError() def with_variant( - self: SelfTypeEngine, type_: TypeEngine[Any], *dialect_names: str + self: SelfTypeEngine, + type_: _TypeEngineArgument[Any], + *dialect_names: str, ) -> SelfTypeEngine: r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 6f60a652ff..beb5d783bf 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -39,6 +39,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import undefer from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.collections import MappedCollection +from sqlalchemy.schema import CreateTable from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message @@ -100,6 +101,32 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(MyClass.__table__.c.data.type, typ) is_true(MyClass.__table__.c.id.primary_key) + @testing.combinations( + (BIGINT(),), + (BIGINT,), + (Integer().with_variant(BIGINT, "default")), + (Integer().with_variant(BIGINT(), "default")), + (BIGINT().with_variant(String(), "some_other_dialect")), + ) + def test_type_map_varieties(self, typ): + + Base = declarative_base(type_annotation_map={int: typ}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + x: Mapped[int] + y: Mapped[int] = mapped_column() + z: Mapped[int] = mapped_column(typ) + + self.assert_compile( + CreateTable(MyClass.__table__), + "CREATE TABLE mytable (id BIGINT NOT NULL, " + "x BIGINT NOT NULL, y BIGINT NOT NULL, z BIGINT NOT NULL, " + "PRIMARY KEY (id))", + ) + def test_required_no_arg(self, decl_base): with expect_raises_message( sa_exc.ArgumentError, -- 2.47.2