]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Rewrite pool reset_on_return parsing using a util function
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 11 Aug 2019 19:24:13 +0000 (15:24 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Aug 2019 02:34:04 +0000 (22:34 -0400)
Choosing a util.symbol() based on a user parameter is about to have
another use case added as part of #4623, so add a generalized solution
ahead of it.

Change-Id: I420631f81af2ffc655995b9cce9ff2ac618c16d7
(cherry picked from commit 2a079cdc76b3a0f5b4f37299d280d328586e2f7e)

lib/sqlalchemy/pool/base.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py

index f5585c6519eb9da1236565e0ef14fe80b0857a7b..410df47f1ad0273279cf653c67756f85c0524a7d 100644 (file)
@@ -204,16 +204,16 @@ class Pool(log.Identified):
         self._invalidate_time = 0
         self._use_threadlocal = use_threadlocal
         self._pre_ping = pre_ping
-        if reset_on_return in ("rollback", True, reset_rollback):
-            self._reset_on_return = reset_rollback
-        elif reset_on_return in ("none", None, False, reset_none):
-            self._reset_on_return = reset_none
-        elif reset_on_return in ("commit", reset_commit):
-            self._reset_on_return = reset_commit
-        else:
-            raise exc.ArgumentError(
-                "Invalid value for 'reset_on_return': %r" % reset_on_return
-            )
+        self._reset_on_return = util.symbol.parse_user_argument(
+            reset_on_return,
+            {
+                reset_rollback: ["rollback", True],
+                reset_none: ["none", None, False],
+                reset_commit: ["commit"],
+            },
+            "reset_on_return",
+            resolve_symbol_names=False,
+        )
 
         self.echo = echo
 
index a5c1f4803465efeedd2697998bab047bee706aa8..b9ce2ebea1dd429b13b699d215cb9df4e6fd22db 100644 (file)
@@ -1347,6 +1347,41 @@ class symbol(object):
         finally:
             symbol._lock.release()
 
+    @classmethod
+    def parse_user_argument(
+        cls, arg, choices, name, resolve_symbol_names=False
+    ):
+        """Given a user parameter, parse the parameter into a chosen symbol.
+
+        The user argument can be a string name that matches the name of a
+        symbol, or the symbol object itself, or any number of alternate choices
+        such as True/False/ None etc.
+
+        :param arg: the user argument.
+        :param choices: dictionary of symbol object to list of possible
+         entries.
+        :param name: name of the argument.   Used in an :class:`.ArgumentError`
+         that is raised if the parameter doesn't match any available argument.
+        :param resolve_symbol_names: include the name of each symbol as a valid
+         entry.
+
+        """
+        # note using hash lookup is tricky here because symbol's `__hash__`
+        # is its int value which we don't want included in the lookup
+        # explicitly, so we iterate and compare each.
+        for sym, choice in choices.items():
+            if arg is sym:
+                return sym
+            elif resolve_symbol_names and arg == sym.name:
+                return sym
+            elif arg in choice:
+                return sym
+
+        if arg is None:
+            return None
+
+        raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+
 
 _creation_order = 1
 
index fea34cf8da948d3ed1e979b8d3286365b95e57f2..a69c44ded120329ea37c952bbb3fdc106194ff84 100644 (file)
@@ -1868,6 +1868,69 @@ class SymbolTest(fixtures.TestBase):
         assert not (sym1 | sym2) & (sym3 | sym4)
         assert (sym1 | sym2) & (sym2 | sym4)
 
+    def test_parser(self):
+        sym1 = util.symbol("sym1", canonical=1)
+        sym2 = util.symbol("sym2", canonical=2)
+        sym3 = util.symbol("sym3", canonical=4)
+        sym4 = util.symbol("sym4", canonical=8)
+
+        lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]}
+        lookup_two = {sym1: [], sym2: [True], sym3: [False]}
+        lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []}
+
+        is_(
+            util.symbol.parse_user_argument(
+                "sym2", lookup_one, "some_name", resolve_symbol_names=True
+            ),
+            sym2,
+        )
+
+        assert_raises_message(
+            exc.ArgumentError,
+            "Invalid value for 'some_name': 'sym2'",
+            util.symbol.parse_user_argument,
+            "sym2",
+            lookup_one,
+            "some_name",
+        )
+        is_(
+            util.symbol.parse_user_argument(
+                True, lookup_one, "some_name", resolve_symbol_names=False
+            ),
+            sym2,
+        )
+
+        is_(
+            util.symbol.parse_user_argument(sym2, lookup_one, "some_name"),
+            sym2,
+        )
+
+        is_(
+            util.symbol.parse_user_argument(None, lookup_one, "some_name"),
+            sym4,
+        )
+
+        is_(
+            util.symbol.parse_user_argument(None, lookup_two, "some_name"),
+            None,
+        )
+
+        is_(
+            util.symbol.parse_user_argument(
+                "symbol2", lookup_three, "some_name"
+            ),
+            sym2,
+        )
+
+        assert_raises_message(
+            exc.ArgumentError,
+            "Invalid value for 'some_name': 'foo'",
+            util.symbol.parse_user_argument,
+            "foo",
+            lookup_three,
+            "some_name",
+        )
+
 
 class _Py3KFixtures(object):
     pass