]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
update pickle tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Jun 2025 18:55:14 +0000 (14:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Jun 2025 19:19:23 +0000 (15:19 -0400)
Since I want to get rid of util.portable_instancemethod, first
make sure we are testing pickle extensively including going through
all protocols for all metadata-oriented tests.

Change-Id: I0064bc16033939780e50c7a8a4ede60ef5835b38

lib/sqlalchemy/dialects/mysql/types.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/fixtures/base.py
lib/sqlalchemy/testing/util.py
test/ext/test_serializer.py
test/sql/test_metadata.py

index 8621f5b9864dd4dd62d2cf1c7caf178367483fb0..d88aace2cc3e3c542f4a79cc7432bcd9ff0677df 100644 (file)
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
     from ...sql.type_api import _BindProcessorType
     from ...sql.type_api import _ResultProcessorType
+    from ...sql.type_api import TypeEngine
 
 
 class _NumericCommonType:
@@ -395,6 +396,12 @@ class TINYINT(_IntegerType):
         """
         super().__init__(display_width=display_width, **kw)
 
+    def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
+        return (
+            self._type_affinity is other._type_affinity
+            or other._type_affinity is sqltypes.Boolean
+        )
+
 
 class SMALLINT(_IntegerType, sqltypes.SMALLINT):
     """MySQL SMALLINTEGER type."""
index 1c324501759b90d4bae8672163f07f141be51b6f..24aa16daa148533bd039dd2b96275b1072527e0b 100644 (file)
@@ -1608,6 +1608,12 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             self.enum_class = None
             return enums, enums  # type: ignore[return-value]
 
+    def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
+        return (
+            super()._compare_type_affinity(other)
+            or other._type_affinity is String
+        )
+
     def _resolve_for_literal(self, value: Any) -> Enum:
         tv = type(value)
         typ = self._resolve_for_python_type(tv, tv, tv)
index 09d45a0a2205e801d3d04ce4a5149f638ca2b769..270a1b7d73ebbd00f67715fd94772f42ad5ad423 100644 (file)
@@ -14,6 +14,7 @@ from .. import assertions
 from .. import config
 from ..assertions import eq_
 from ..util import drop_all_tables_from_metadata
+from ..util import picklers
 from ... import Column
 from ... import func
 from ... import Integer
@@ -194,6 +195,10 @@ class TestBase:
 
         return go
 
+    @config.fixture(params=picklers())
+    def picklers(self, request):
+        yield request.param
+
     @config.fixture()
     def metadata(self, request):
         """Provide bound MetaData for a single test, dropping afterwards."""
index 42f077108f588bebd050312c487f8b4d6e7b34c6..21dddfa2ec17bd6526af25be65f1576d5cb0362c 100644 (file)
 from __future__ import annotations
 
 from collections import deque
+from collections import namedtuple
 import contextlib
 import decimal
 import gc
 from itertools import chain
+import pickle
 import random
 import sys
 from sys import getsizeof
@@ -55,15 +57,10 @@ else:
 
 
 def picklers():
-    picklers = set()
-    import pickle
+    nt = namedtuple("picklers", ["loads", "dumps"])
 
-    picklers.add(pickle)
-
-    # yes, this thing needs this much testing
-    for pickle_ in picklers:
-        for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
-            yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
+    for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
+        yield nt(pickle.loads, lambda d: pickle.dumps(d, protocol))
 
 
 def random_choices(population, k=1):
index fb92c752a67707314ebc2d2804c891b490d4e73c..ffda82a538e4bd50b5094953ce2049e318bbddf7 100644 (file)
@@ -1,3 +1,5 @@
+import pickle
+
 from sqlalchemy import desc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -27,8 +29,7 @@ from sqlalchemy.testing.schema import Table
 
 
 def pickle_protocols():
-    return iter([-1, 1, 2])
-    # return iter([-1, 0, 1, 2])
+    return range(-2, pickle.HIGHEST_PROTOCOL)
 
 
 class User(ComparableEntity):
index 0b5f70573204981ef84005404230f6d973d2ea87..e963fca6a3b5b8ca6595d76902cdb1eee6d8547a 100644 (file)
@@ -520,7 +520,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
         t.c.x._init_items(s1)
         assert s1.metadata is m1
 
-    def test_pickle_metadata_sequence_implicit(self):
+    def test_pickle_metadata_sequence_implicit(self, picklers):
         m1 = MetaData()
         Table(
             "a",
@@ -529,13 +529,13 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
             Column("x", Integer, Sequence("x_seq")),
         )
 
-        m2 = pickle.loads(pickle.dumps(m1))
+        m2 = picklers.loads(picklers.dumps(m1))
 
         t2 = Table("a", m2, extend_existing=True)
 
         eq_(m2._sequences, {"x_seq": t2.c.x.default})
 
-    def test_pickle_metadata_schema(self):
+    def test_pickle_metadata_schema(self, picklers):
         m1 = MetaData()
         Table(
             "a",
@@ -545,7 +545,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
             schema="y",
         )
 
-        m2 = pickle.loads(pickle.dumps(m1))
+        m2 = picklers.loads(picklers.dumps(m1))
 
         Table("a", m2, schema="y", extend_existing=True)
 
@@ -813,19 +813,27 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
 
 
 class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
-    @testing.requires.check_constraints
-    def test_copy(self):
-        # TODO: modernize this test for 2.0
 
+    @testing.fixture
+    def copy_fixture(self, metadata):
         from sqlalchemy.testing.schema import Table
 
-        meta = MetaData()
-
         table = Table(
             "mytable",
-            meta,
+            metadata,
             Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True),
             Column("name", String(40), nullable=True),
+            Column("status", Boolean(create_constraint=True)),
+            Column(
+                "entry",
+                Enum(
+                    "one",
+                    "two",
+                    "three",
+                    name="entry_enum",
+                    create_constraint=True,
+                ),
+            ),
             Column(
                 "foo",
                 String(40),
@@ -845,7 +853,7 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
 
         table2 = Table(
             "othertable",
-            meta,
+            metadata,
             Column("id", Integer, Sequence("foo_seq"), primary_key=True),
             Column("myid", Integer, ForeignKey("mytable.myid")),
             test_needs_fk=True,
@@ -853,103 +861,119 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
 
         table3 = Table(
             "has_comments",
-            meta,
+            metadata,
             Column("foo", Integer, comment="some column"),
             comment="table comment",
         )
 
-        def test_to_metadata():
+        metadata.create_all(testing.db)
+
+        return table, table2, table3
+
+    @testing.fixture(
+        params=[
+            "to_metadata",
+            "pickle",
+            "pickle_via_reflect",
+        ]
+    )
+    def copy_tables_fixture(self, request, metadata, copy_fixture, picklers):
+        table, table2, table3 = copy_fixture
+
+        test = request.param
+
+        if test == "to_metadata":
             meta2 = MetaData()
             table_c = table.to_metadata(meta2)
             table2_c = table2.to_metadata(meta2)
             table3_c = table3.to_metadata(meta2)
-            return (table_c, table2_c, table3_c)
+            return (table_c, table2_c, table3_c, (True, False))
 
-        def test_pickle():
-            meta.bind = testing.db
-            meta2 = pickle.loads(pickle.dumps(meta))
-            pickle.loads(pickle.dumps(meta2))
+        elif test == "pickle":
+            meta2 = picklers.loads(picklers.dumps(metadata))
+            picklers.loads(picklers.dumps(meta2))
             return (
                 meta2.tables["mytable"],
                 meta2.tables["othertable"],
                 meta2.tables["has_comments"],
+                (True, False),
             )
 
-        def test_pickle_via_reflect():
+        elif test == "pickle_via_reflect":
             # this is the most common use case, pickling the results of a
             # database reflection
             meta2 = MetaData()
             t1 = Table("mytable", meta2, autoload_with=testing.db)
             Table("othertable", meta2, autoload_with=testing.db)
             Table("has_comments", meta2, autoload_with=testing.db)
-            meta3 = pickle.loads(pickle.dumps(meta2))
+            meta3 = picklers.loads(picklers.dumps(meta2))
             assert meta3.tables["mytable"] is not t1
 
             return (
                 meta3.tables["mytable"],
                 meta3.tables["othertable"],
                 meta3.tables["has_comments"],
+                (False, True),
             )
 
-        meta.create_all(testing.db)
-        try:
-            for test, has_constraints, reflect in (
-                (test_to_metadata, True, False),
-                (test_pickle, True, False),
-                (test_pickle_via_reflect, False, True),
-            ):
-                table_c, table2_c, table3_c = test()
-                self.assert_tables_equal(table, table_c)
-                self.assert_tables_equal(table2, table2_c)
-                assert table is not table_c
-                assert table.primary_key is not table_c.primary_key
-                assert (
-                    list(table2_c.c.myid.foreign_keys)[0].column
-                    is table_c.c.myid
-                )
-                assert (
-                    list(table2_c.c.myid.foreign_keys)[0].column
-                    is not table.c.myid
+        assert False
+
+    @testing.requires.check_constraints
+    def test_copy(self, metadata, copy_fixture, copy_tables_fixture):
+
+        table, table2, table3 = copy_fixture
+        table_c, table2_c, table3_c, (has_constraints, reflect) = (
+            copy_tables_fixture
+        )
+
+        self.assert_tables_equal(table, table_c)
+        self.assert_tables_equal(table2, table2_c)
+        assert table is not table_c
+        assert table.primary_key is not table_c.primary_key
+        assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid
+        assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid
+        assert "x" in str(table_c.c.foo.server_default.arg)
+        if not reflect:
+            assert isinstance(table_c.c.myid.default, Sequence)
+            assert str(table_c.c.foo.server_onupdate.arg) == "q"
+            assert str(table_c.c.bar.default.arg) == "y"
+            assert (
+                getattr(
+                    table_c.c.bar.onupdate.arg,
+                    "arg",
+                    table_c.c.bar.onupdate.arg,
                 )
-                assert "x" in str(table_c.c.foo.server_default.arg)
-                if not reflect:
-                    assert isinstance(table_c.c.myid.default, Sequence)
-                    assert str(table_c.c.foo.server_onupdate.arg) == "q"
-                    assert str(table_c.c.bar.default.arg) == "y"
-                    assert (
-                        getattr(
-                            table_c.c.bar.onupdate.arg,
-                            "arg",
-                            table_c.c.bar.onupdate.arg,
-                        )
-                        == "z"
-                    )
-                    assert isinstance(table2_c.c.id.default, Sequence)
-
-                # constraints don't get reflected for any dialect right
-                # now
-
-                if has_constraints:
-                    for c in table_c.c.description.constraints:
-                        if isinstance(c, CheckConstraint):
-                            break
-                    else:
-                        assert False
-                    assert str(c.sqltext) == "description='hi'"
-                    for c in table_c.constraints:
-                        if isinstance(c, UniqueConstraint):
-                            break
-                    else:
-                        assert False
-                    assert c.columns.contains_column(table_c.c.name)
-                    assert not c.columns.contains_column(table.c.name)
-
-                if testing.requires.comment_reflection.enabled:
-                    eq_(table3_c.comment, "table comment")
-                    eq_(table3_c.c.foo.comment, "some column")
+                == "z"
+            )
+            assert isinstance(table2_c.c.id.default, Sequence)
 
-        finally:
-            meta.drop_all(testing.db)
+        if testing.requires.unique_constraint_reflection.enabled:
+            for c in table_c.constraints:
+                if isinstance(c, UniqueConstraint):
+                    break
+            else:
+                for c in table_c.indexes:
+                    break
+                else:
+                    assert False
+
+            assert c.columns.contains_column(table_c.c.name)
+            assert not c.columns.contains_column(table.c.name)
+
+        # CHECK constraints don't get reflected for any dialect right
+        # now
+
+        if has_constraints:
+            for c in table_c.c.description.constraints:
+                if isinstance(c, CheckConstraint):
+                    break
+            else:
+                assert False
+            assert str(c.sqltext) == "description='hi'"
+
+        if testing.requires.comment_reflection.enabled:
+            eq_(table3_c.comment, "table comment")
+            eq_(table3_c.c.foo.comment, "some column")
 
     def test_col_key_fk_parent(self):
         # test #2643