From: Mike Bayer Date: Wed, 11 Jun 2025 18:55:14 +0000 (-0400) Subject: update pickle tests X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=239f629b9a94b315c289930cadca4a49f2f70565;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git update pickle tests Since I want to get rid of util.portable_instancemethod, first make sure we are testing pickle extensively including going through all protocols for all metadata-oriented tests. Change-Id: I0064bc16033939780e50c7a8a4ede60ef5835b38 --- diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index 8621f5b986..d88aace2cc 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from ...engine.interfaces import Dialect from ...sql.type_api import _BindProcessorType from ...sql.type_api import _ResultProcessorType + from ...sql.type_api import TypeEngine class _NumericCommonType: @@ -395,6 +396,12 @@ class TINYINT(_IntegerType): """ super().__init__(display_width=display_width, **kw) + def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool: + return ( + self._type_affinity is other._type_affinity + or other._type_affinity is sqltypes.Boolean + ) + class SMALLINT(_IntegerType, sqltypes.SMALLINT): """MySQL SMALLINTEGER type.""" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 1c32450175..24aa16daa1 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1608,6 +1608,12 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self.enum_class = None return enums, enums # type: ignore[return-value] + def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool: + return ( + super()._compare_type_affinity(other) + or other._type_affinity is String + ) + def _resolve_for_literal(self, value: Any) -> Enum: tv = type(value) typ = self._resolve_for_python_type(tv, tv, tv) diff --git a/lib/sqlalchemy/testing/fixtures/base.py b/lib/sqlalchemy/testing/fixtures/base.py index 09d45a0a22..270a1b7d73 100644 --- a/lib/sqlalchemy/testing/fixtures/base.py +++ b/lib/sqlalchemy/testing/fixtures/base.py @@ -14,6 +14,7 @@ from .. import assertions from .. import config from ..assertions import eq_ from ..util import drop_all_tables_from_metadata +from ..util import picklers from ... import Column from ... import func from ... import Integer @@ -194,6 +195,10 @@ class TestBase: return go + @config.fixture(params=picklers()) + def picklers(self, request): + yield request.param + @config.fixture() def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 42f077108f..21dddfa2ec 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -10,10 +10,12 @@ from __future__ import annotations from collections import deque +from collections import namedtuple import contextlib import decimal import gc from itertools import chain +import pickle import random import sys from sys import getsizeof @@ -55,15 +57,10 @@ else: def picklers(): - picklers = set() - import pickle + nt = namedtuple("picklers", ["loads", "dumps"]) - picklers.add(pickle) - - # yes, this thing needs this much testing - for pickle_ in picklers: - for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1): - yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1): + yield nt(pickle.loads, lambda d: pickle.dumps(d, protocol)) def random_choices(population, k=1): diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index fb92c752a6..ffda82a538 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -1,3 +1,5 @@ +import pickle + from sqlalchemy import desc from sqlalchemy import ForeignKey from sqlalchemy import func @@ -27,8 +29,7 @@ from sqlalchemy.testing.schema import Table def pickle_protocols(): - return iter([-1, 1, 2]) - # return iter([-1, 0, 1, 2]) + return range(-2, pickle.HIGHEST_PROTOCOL) class User(ComparableEntity): diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 0b5f705732..e963fca6a3 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -520,7 +520,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): t.c.x._init_items(s1) assert s1.metadata is m1 - def test_pickle_metadata_sequence_implicit(self): + def test_pickle_metadata_sequence_implicit(self, picklers): m1 = MetaData() Table( "a", @@ -529,13 +529,13 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): Column("x", Integer, Sequence("x_seq")), ) - m2 = pickle.loads(pickle.dumps(m1)) + m2 = picklers.loads(picklers.dumps(m1)) t2 = Table("a", m2, extend_existing=True) eq_(m2._sequences, {"x_seq": t2.c.x.default}) - def test_pickle_metadata_schema(self): + def test_pickle_metadata_schema(self, picklers): m1 = MetaData() Table( "a", @@ -545,7 +545,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): schema="y", ) - m2 = pickle.loads(pickle.dumps(m1)) + m2 = picklers.loads(picklers.dumps(m1)) Table("a", m2, schema="y", extend_existing=True) @@ -813,19 +813,27 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): - @testing.requires.check_constraints - def test_copy(self): - # TODO: modernize this test for 2.0 + @testing.fixture + def copy_fixture(self, metadata): from sqlalchemy.testing.schema import Table - meta = MetaData() - table = Table( "mytable", - meta, + metadata, Column("myid", Integer, Sequence("foo_id_seq"), primary_key=True), Column("name", String(40), nullable=True), + Column("status", Boolean(create_constraint=True)), + Column( + "entry", + Enum( + "one", + "two", + "three", + name="entry_enum", + create_constraint=True, + ), + ), Column( "foo", String(40), @@ -845,7 +853,7 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): table2 = Table( "othertable", - meta, + metadata, Column("id", Integer, Sequence("foo_seq"), primary_key=True), Column("myid", Integer, ForeignKey("mytable.myid")), test_needs_fk=True, @@ -853,103 +861,119 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): table3 = Table( "has_comments", - meta, + metadata, Column("foo", Integer, comment="some column"), comment="table comment", ) - def test_to_metadata(): + metadata.create_all(testing.db) + + return table, table2, table3 + + @testing.fixture( + params=[ + "to_metadata", + "pickle", + "pickle_via_reflect", + ] + ) + def copy_tables_fixture(self, request, metadata, copy_fixture, picklers): + table, table2, table3 = copy_fixture + + test = request.param + + if test == "to_metadata": meta2 = MetaData() table_c = table.to_metadata(meta2) table2_c = table2.to_metadata(meta2) table3_c = table3.to_metadata(meta2) - return (table_c, table2_c, table3_c) + return (table_c, table2_c, table3_c, (True, False)) - def test_pickle(): - meta.bind = testing.db - meta2 = pickle.loads(pickle.dumps(meta)) - pickle.loads(pickle.dumps(meta2)) + elif test == "pickle": + meta2 = picklers.loads(picklers.dumps(metadata)) + picklers.loads(picklers.dumps(meta2)) return ( meta2.tables["mytable"], meta2.tables["othertable"], meta2.tables["has_comments"], + (True, False), ) - def test_pickle_via_reflect(): + elif test == "pickle_via_reflect": # this is the most common use case, pickling the results of a # database reflection meta2 = MetaData() t1 = Table("mytable", meta2, autoload_with=testing.db) Table("othertable", meta2, autoload_with=testing.db) Table("has_comments", meta2, autoload_with=testing.db) - meta3 = pickle.loads(pickle.dumps(meta2)) + meta3 = picklers.loads(picklers.dumps(meta2)) assert meta3.tables["mytable"] is not t1 return ( meta3.tables["mytable"], meta3.tables["othertable"], meta3.tables["has_comments"], + (False, True), ) - meta.create_all(testing.db) - try: - for test, has_constraints, reflect in ( - (test_to_metadata, True, False), - (test_pickle, True, False), - (test_pickle_via_reflect, False, True), - ): - table_c, table2_c, table3_c = test() - self.assert_tables_equal(table, table_c) - self.assert_tables_equal(table2, table2_c) - assert table is not table_c - assert table.primary_key is not table_c.primary_key - assert ( - list(table2_c.c.myid.foreign_keys)[0].column - is table_c.c.myid - ) - assert ( - list(table2_c.c.myid.foreign_keys)[0].column - is not table.c.myid + assert False + + @testing.requires.check_constraints + def test_copy(self, metadata, copy_fixture, copy_tables_fixture): + + table, table2, table3 = copy_fixture + table_c, table2_c, table3_c, (has_constraints, reflect) = ( + copy_tables_fixture + ) + + self.assert_tables_equal(table, table_c) + self.assert_tables_equal(table2, table2_c) + assert table is not table_c + assert table.primary_key is not table_c.primary_key + assert list(table2_c.c.myid.foreign_keys)[0].column is table_c.c.myid + assert list(table2_c.c.myid.foreign_keys)[0].column is not table.c.myid + assert "x" in str(table_c.c.foo.server_default.arg) + if not reflect: + assert isinstance(table_c.c.myid.default, Sequence) + assert str(table_c.c.foo.server_onupdate.arg) == "q" + assert str(table_c.c.bar.default.arg) == "y" + assert ( + getattr( + table_c.c.bar.onupdate.arg, + "arg", + table_c.c.bar.onupdate.arg, ) - assert "x" in str(table_c.c.foo.server_default.arg) - if not reflect: - assert isinstance(table_c.c.myid.default, Sequence) - assert str(table_c.c.foo.server_onupdate.arg) == "q" - assert str(table_c.c.bar.default.arg) == "y" - assert ( - getattr( - table_c.c.bar.onupdate.arg, - "arg", - table_c.c.bar.onupdate.arg, - ) - == "z" - ) - assert isinstance(table2_c.c.id.default, Sequence) - - # constraints don't get reflected for any dialect right - # now - - if has_constraints: - for c in table_c.c.description.constraints: - if isinstance(c, CheckConstraint): - break - else: - assert False - assert str(c.sqltext) == "description='hi'" - for c in table_c.constraints: - if isinstance(c, UniqueConstraint): - break - else: - assert False - assert c.columns.contains_column(table_c.c.name) - assert not c.columns.contains_column(table.c.name) - - if testing.requires.comment_reflection.enabled: - eq_(table3_c.comment, "table comment") - eq_(table3_c.c.foo.comment, "some column") + == "z" + ) + assert isinstance(table2_c.c.id.default, Sequence) - finally: - meta.drop_all(testing.db) + if testing.requires.unique_constraint_reflection.enabled: + for c in table_c.constraints: + if isinstance(c, UniqueConstraint): + break + else: + for c in table_c.indexes: + break + else: + assert False + + assert c.columns.contains_column(table_c.c.name) + assert not c.columns.contains_column(table.c.name) + + # CHECK constraints don't get reflected for any dialect right + # now + + if has_constraints: + for c in table_c.c.description.constraints: + if isinstance(c, CheckConstraint): + break + else: + assert False + assert str(c.sqltext) == "description='hi'" + + if testing.requires.comment_reflection.enabled: + eq_(table3_c.comment, "table comment") + eq_(table3_c.c.foo.comment, "some column") def test_col_key_fk_parent(self): # test #2643