]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add some typing tests for declared_attr, mixins, decl base
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Jul 2022 18:36:25 +0000 (14:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Jul 2022 14:40:49 +0000 (10:40 -0400)
to make typing easier, looking at using getattr on table.c
rather than getitem for now

Change-Id: I7946885071d0b0ddfc06be009f033495f9906de5

test/ext/mypy/plain_files/declared_attr_one.py [new file with mode: 0644]
test/ext/mypy/plain_files/declared_attr_two.py [new file with mode: 0644]
test/orm/declarative/test_mixin.py

diff --git a/test/ext/mypy/plain_files/declared_attr_one.py b/test/ext/mypy/plain_files/declared_attr_one.py
new file mode 100644 (file)
index 0000000..969c86c
--- /dev/null
@@ -0,0 +1,69 @@
+from datetime import datetime
+import typing
+
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Employee(Base):
+    __tablename__ = "employee"
+    id = mapped_column(Integer, primary_key=True)
+    name = mapped_column(String(50))
+    type = mapped_column(String(20))
+
+    __mapper_args__ = {
+        "polymorphic_on": type,
+        "polymorphic_identity": "employee",
+    }
+
+
+class Engineer(Employee):
+    __mapper_args__ = {
+        "polymorphic_identity": "engineer",
+    }
+
+    @declared_attr
+    def start_date(cls) -> Mapped[datetime]:
+        "Start date column, if not present already."
+
+        assert Employee.__table__ is not None
+        return getattr(
+            Employee.__table__.c,
+            "start date",
+            mapped_column("start date", DateTime),
+        )
+
+
+class Manager(Employee):
+    __mapper_args__ = {
+        "polymorphic_identity": "manager",
+    }
+
+    @declared_attr
+    def start_date(cls) -> Mapped[datetime]:
+        "Start date column, if not present already."
+
+        assert Employee.__table__ is not None
+        return getattr(
+            Employee.__table__.c,
+            "start date",
+            mapped_column("start date", DateTime),
+        )
+
+
+if typing.TYPE_CHECKING:
+
+    # EXPECTED_TYPE: InstrumentedAttribute[datetime]
+    reveal_type(Engineer.start_date)
+
+    # EXPECTED_TYPE: InstrumentedAttribute[datetime]
+    reveal_type(Manager.start_date)
diff --git a/test/ext/mypy/plain_files/declared_attr_two.py b/test/ext/mypy/plain_files/declared_attr_two.py
new file mode 100644 (file)
index 0000000..86e679e
--- /dev/null
@@ -0,0 +1,38 @@
+import typing
+
+from sqlalchemy import Integer
+from sqlalchemy import Text
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class HasRelatedDataMixin:
+    @declared_attr
+    def related_data(cls) -> Mapped[str]:
+        return mapped_column(Text(), deferred=True)
+
+
+class User(HasRelatedDataMixin, Base):
+    __tablename__ = "user"
+    id = mapped_column(Integer, primary_key=True)
+
+
+u1 = User()
+
+
+u1.related_data
+
+
+if typing.TYPE_CHECKING:
+
+    # EXPECTED_TYPE: str
+    reveal_type(u1.related_data)
+
+    # EXPECTED_TYPE: InstrumentedAttribute[str]
+    reveal_type(User.related_data)
index e6f669c3cdc73743940785481a266e81b95c34c7..72e14ceebd06c7dd12d921422e574eeec1563cd3 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy.orm import column_property
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import declarative_mixin
+from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import events as orm_events
@@ -52,7 +53,9 @@ class DeclarativeTestBase(
         global Base, mapper_registry
 
         mapper_registry = registry(metadata=MetaData())
-        Base = mapper_registry.generate_base()
+
+        class Base(DeclarativeBase):
+            registry = mapper_registry
 
     def teardown_test(self):
         close_all_sessions()
@@ -62,12 +65,29 @@ class DeclarativeTestBase(
 
 
 class DeclarativeMixinTest(DeclarativeTestBase):
-    def test_init_subclass_works(self, registry):
-        class Base:
-            def __init_subclass__(cls):
-                cls.id = Column(Integer, primary_key=True)
+    @testing.combinations("generate_base", "subclass", argnames="base_type")
+    def test_init_subclass_works(self, registry, base_type):
+        reg = registry
+        if base_type == "generate_base":
+
+            class Base:
+                def __init_subclass__(cls):
+                    cls.id = Column(Integer, primary_key=True)
+
+            Base = registry.generate_base(cls=Base)
+        elif base_type == "subclass":
+
+            class Base(DeclarativeBase):
+                registry = reg
+
+                def __init_subclass__(cls):
+                    cls.id = Column(Integer, primary_key=True)
+                    # hmmm what do we think of this.  if DeclarativeBase
+                    # used a full metaclass approach we wouldn't need this.
+                    super().__init_subclass__()
 
-        Base = registry.generate_base(cls=Base)
+        else:
+            assert False
 
         class Foo(Base):
             __tablename__ = "foo"
@@ -127,10 +147,6 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         eq_(obj.foo(), "bar1")
 
     def test_declarative_mixin_decorator(self):
-
-        # note we are also making sure an "old style class" in Python 2,
-        # as we are now illustrating in all the docs for mixins, doesn't cause
-        # a problem....
         @declarative_mixin
         class MyMixin:
 
@@ -141,10 +157,8 @@ class DeclarativeMixinTest(DeclarativeTestBase):
             def foo(self):
                 return "bar" + str(self.id)
 
-        # ...as long as the mapped class itself is "new style", which will
-        # normally be the case for users using declarative_base
         @mapper_registry.mapped
-        class MyModel(MyMixin, object):
+        class MyModel(MyMixin):
 
             __tablename__ = "test"
             name = Column(String(100), nullable=False, index=True)
@@ -645,7 +659,8 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         lambda: testing.against("oracle"),
         "Test has an empty insert in it at the moment",
     )
-    def test_columns_single_inheritance_conflict_resolution(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_columns_single_inheritance_conflict_resolution(self, _column):
         """Test that a declared_attr can return the existing column and it will
         be ignored.  this allows conditional columns to be added.
 
@@ -655,13 +670,13 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         class Person(Base):
             __tablename__ = "person"
-            id = Column(Integer, primary_key=True)
+            id = _column(Integer, primary_key=True)
 
         class Mixin:
             @declared_attr
             def target_id(cls):
                 return cls.__table__.c.get(
-                    "target_id", Column(Integer, ForeignKey("other.id"))
+                    "target_id", _column(Integer, ForeignKey("other.id"))
                 )
 
             @declared_attr
@@ -678,7 +693,7 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         class Other(Base):
             __tablename__ = "other"
-            id = Column(Integer, primary_key=True)
+            id = _column(Integer, primary_key=True)
 
         is_(
             Engineer.target_id.property.columns[0],
@@ -697,7 +712,8 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         session.commit()
         eq_(session.query(Engineer).first().target, o1)
 
-    def test_columns_joined_table_inheritance(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_columns_joined_table_inheritance(self, _column):
         """Test a column on a mixin with an alternate attribute name,
         mapped to a superclass and joined-table inheritance subclass.
         Both tables get the column, in the case of the subclass the two
@@ -706,18 +722,18 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         """
 
         class MyMixin:
-            foo = Column("foo", Integer)
-            bar = Column("bar_newname", Integer)
+            foo = _column("foo", Integer)
+            bar = _column("bar_newname", Integer)
 
         class General(Base, MyMixin):
             __tablename__ = "test"
-            id = Column(Integer, primary_key=True)
-            type_ = Column(String(50))
+            id = _column(Integer, primary_key=True)
+            type_ = _column(String(50))
             __mapper__args = {"polymorphic_on": type_}
 
         class Specific(General):
             __tablename__ = "sub"
-            id = Column(Integer, ForeignKey("test.id"), primary_key=True)
+            id = _column(Integer, ForeignKey("test.id"), primary_key=True)
             __mapper_args__ = {"polymorphic_identity": "specific"}
 
         assert General.bar.prop.columns[0] is General.__table__.c.bar_newname