]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add new variation helper
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2022 16:03:45 +0000 (11:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2022 22:25:35 +0000 (17:25 -0500)
I'm using a lot of @testing.combinations with either
a boolean True/False, or a series of string names, each indicating
some case to switch on.  I want a descriptive name in the test
run (not True/False) and I don't want to compare strings.

So make a new helper around @combinations that provides an
object interface that has booleans inside of it, prints nicely
in the test output, raises an error if you name the case
incorrectly.

Before:

test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-False-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-False-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-False-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-True-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-True-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[False-True-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-False-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-False-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-False-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-True-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-True-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[True-True-name] PASSED

After:

test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-deferred-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-deferred-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-deferred-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-not_deferred-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-not_deferred-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[not_use_add_property-not_deferred-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-deferred-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-deferred-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-deferred-name] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-not_deferred-both] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-not_deferred-key] PASSED
test/orm/declarative/test_typed_mapping.py::MappedColumnTest::test_separate_name[use_add_property-not_deferred-name] PASSED

Change-Id: Idde87632581ee69e0f47360966758583dfd8baab
(cherry picked from commit 3ffa8dccc224d7b7d604bdfb684c437f4cb42f92)

lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/config.py
test/orm/test_merge.py

index 73b43f04bd43e6e2fcc716edf62a93ade3d7a90a..7d47210452de5c8b03f09993f0d8a9337762311d 100644 (file)
@@ -50,6 +50,7 @@ from .config import db
 from .config import fixture
 from .config import requirements as requires
 from .config import skip_test
+from .config import variation
 from .exclusions import _is_excluded
 from .exclusions import _server_version
 from .exclusions import against as _against
index fc13a165579c77816b05cf3d6d02df5755bd2d6b..7d19b99be5e3927dae212d68d2bcbdf2f9270b06 100644 (file)
@@ -94,6 +94,81 @@ def combinations_list(arg_iterable, **kw):
     return combinations(*arg_iterable, **kw)
 
 
+class _variation_base(object):
+    __slots__ = ("name", "argname")
+
+    def __init__(self, case, argname, case_names):
+        self.name = case
+        self.argname = argname
+        for casename in case_names:
+            setattr(self, casename, casename == case)
+
+    def __bool__(self):
+        return self.name == self.argname
+
+    def __nonzero__(self):
+        return not self.__bool__()
+
+
+def variation(argname, cases):
+    """a helper around testing.combinations that provides a single namespace
+    that can be used as a switch.
+
+    e.g.::
+
+        @testing.variation("querytyp", ["select", "subquery", "legacy_query"])
+        @testing.variation("lazy", ["select", "raise", "raise_on_sql"])
+        def test_thing(
+            self,
+            querytyp,
+            lazy,
+            decl_base
+        ):
+            class Thing(decl_base):
+                __tablename__ = 'thing'
+
+                # use name directly
+                rel = relationship("Rel", lazy=lazy.name)
+
+            # use as a switch
+            if querytyp.select:
+                stmt = select(Thing)
+            elif querytyp.subquery:
+                stmt = select(Thing).subquery()
+            elif querytyp.legacy_query:
+                stmt = Session.query(Thing)
+            else:
+                assert False
+
+
+    The variable provided is a slots object of boolean variables, as well
+    as the name of the case itself under the attribute ".name"
+
+    """
+
+    case_names = [
+        argname if c is True else "not_" + argname if c is False else c
+        for c in cases
+    ]
+
+    typ = type(
+        argname,
+        (_variation_base,),
+        {
+            "__slots__": tuple(case_names),
+        },
+    )
+
+    return combinations(
+        *[
+            (casename, typ(casename, argname, case_names))
+            for casename in case_names
+        ],
+        id_="ia",
+        argnames=argname
+    )
+
+
 def fixture(*arg, **kw):
     return _fixture_functions.fixture(*arg, **kw)
 
index 67307ed6cfd2299f84975cb0b2cdf13791e75213..4f3b4e49561f3b7c7354cdf847707e933f51f722 100644 (file)
@@ -1399,14 +1399,14 @@ class MergeTest(_fixtures.FixtureTest):
         except sa.exc.InvalidRequestError as e:
             assert "load=False option does not support" in str(e)
 
-    @testing.combinations("viewonly", "normal", argnames="viewonly")
-    @testing.combinations("load", "noload", argnames="load")
-    @testing.combinations("select", "raise", "raise_on_sql", argnames="lazy")
-    @testing.combinations(
-        "merge_persistent", "merge_detached", argnames="merge_persistent"
+    @testing.variation("viewonly", ["viewonly", "normal"])
+    @testing.variation("load", ["load", "noload"])
+    @testing.variation("lazy", ["select", "raise", "raise_on_sql"])
+    @testing.variation(
+        "merge_persistent", ["merge_persistent", "merge_detached"]
     )
-    @testing.combinations("detached", "persistent", argnames="detach_original")
-    @testing.combinations("o2m", "m2o", argnames="direction")
+    @testing.variation("detach_original", ["detach", "persistent"])
+    @testing.variation("direction", ["o2m", "m2o"])
     def test_relationship_population_maintained(
         self,
         viewonly,
@@ -1427,8 +1427,8 @@ class MergeTest(_fixtures.FixtureTest):
             properties={
                 "addresses": relationship(
                     Address,
-                    viewonly=viewonly == "viewonly",
-                    lazy=lazy,
+                    viewonly=viewonly.viewonly,
+                    lazy=lazy.name,
                     back_populates="user",
                     order_by=addresses.c.id,
                 )
@@ -1441,8 +1441,8 @@ class MergeTest(_fixtures.FixtureTest):
             properties={
                 "user": relationship(
                     User,
-                    viewonly=viewonly == "viewonly",
-                    lazy=lazy,
+                    viewonly=viewonly.viewonly,
+                    lazy=lazy.name,
                     back_populates="addresses",
                 )
             },
@@ -1458,7 +1458,7 @@ class MergeTest(_fixtures.FixtureTest):
         )
         s.commit()
 
-        if direction == "o2m":
+        if direction.o2m:
             cls_to_merge = User
             obj_to_merge = (
                 s.scalars(select(User).options(joinedload(User.addresses)))
@@ -1467,7 +1467,7 @@ class MergeTest(_fixtures.FixtureTest):
             )
             attrname = "addresses"
 
-        elif direction == "m2o":
+        elif direction.m2o:
             cls_to_merge = Address
             obj_to_merge = (
                 s.scalars(
@@ -1486,21 +1486,21 @@ class MergeTest(_fixtures.FixtureTest):
 
         s2 = Session(testing.db)
 
-        if merge_persistent == "merge_persistent":
+        if merge_persistent.merge_persistent:
             target_persistent = s2.get(cls_to_merge, obj_to_merge.id)  # noqa
 
-        if detach_original == "detach":
+        if detach_original.detach:
             s.expunge(obj_to_merge)
 
         with self.sql_execution_asserter(testing.db) as assert_:
-            merged_object = s2.merge(obj_to_merge, load=load == "load")
+            merged_object = s2.merge(obj_to_merge, load=load.load)
 
         assert_.assert_(
             CountStatements(
                 0
-                if load == "noload"
+                if load.noload
                 else 1
-                if merge_persistent == "merge_persistent"
+                if merge_persistent.merge_persistent
                 else 2
             )
         )
@@ -1508,7 +1508,7 @@ class MergeTest(_fixtures.FixtureTest):
         assert attrname in merged_object.__dict__
 
         with self.sql_execution_asserter(testing.db) as assert_:
-            if direction == "o2m":
+            if direction.o2m:
                 eq_(
                     merged_object.addresses,
                     [
@@ -1516,7 +1516,7 @@ class MergeTest(_fixtures.FixtureTest):
                         for i in range(1, 4)
                     ],
                 )
-            elif direction == "m2o":
+            elif direction.m2o:
                 eq_(merged_object.user, User(id=1, name="u1"))
         assert_.assert_(CountStatements(0))