From: Mike Bayer Date: Sun, 11 Aug 2019 19:24:13 +0000 (-0400) Subject: Rewrite pool reset_on_return parsing using a util function X-Git-Tag: rel_1_3_7~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ff37cd87ef60c801cd068b8d7834947d196f9040;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Rewrite pool reset_on_return parsing using a util function Choosing a util.symbol() based on a user parameter is about to have another use case added as part of #4623, so add a generalized solution ahead of it. Change-Id: I420631f81af2ffc655995b9cce9ff2ac618c16d7 (cherry picked from commit 2a079cdc76b3a0f5b4f37299d280d328586e2f7e) --- diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index f5585c6519..410df47f1a 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -204,16 +204,16 @@ class Pool(log.Identified): self._invalidate_time = 0 self._use_threadlocal = use_threadlocal self._pre_ping = pre_ping - if reset_on_return in ("rollback", True, reset_rollback): - self._reset_on_return = reset_rollback - elif reset_on_return in ("none", None, False, reset_none): - self._reset_on_return = reset_none - elif reset_on_return in ("commit", reset_commit): - self._reset_on_return = reset_commit - else: - raise exc.ArgumentError( - "Invalid value for 'reset_on_return': %r" % reset_on_return - ) + self._reset_on_return = util.symbol.parse_user_argument( + reset_on_return, + { + reset_rollback: ["rollback", True], + reset_none: ["none", None, False], + reset_commit: ["commit"], + }, + "reset_on_return", + resolve_symbol_names=False, + ) self.echo = echo diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index a5c1f48034..b9ce2ebea1 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1347,6 +1347,41 @@ class symbol(object): finally: symbol._lock.release() + @classmethod + def parse_user_argument( + cls, arg, choices, name, resolve_symbol_names=False + ): + """Given a user parameter, parse the parameter into a chosen symbol. + + The user argument can be a string name that matches the name of a + symbol, or the symbol object itself, or any number of alternate choices + such as True/False/ None etc. + + :param arg: the user argument. + :param choices: dictionary of symbol object to list of possible + entries. + :param name: name of the argument. Used in an :class:`.ArgumentError` + that is raised if the parameter doesn't match any available argument. + :param resolve_symbol_names: include the name of each symbol as a valid + entry. + + """ + # note using hash lookup is tricky here because symbol's `__hash__` + # is its int value which we don't want included in the lookup + # explicitly, so we iterate and compare each. + for sym, choice in choices.items(): + if arg is sym: + return sym + elif resolve_symbol_names and arg == sym.name: + return sym + elif arg in choice: + return sym + + if arg is None: + return None + + raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + _creation_order = 1 diff --git a/test/base/test_utils.py b/test/base/test_utils.py index fea34cf8da..a69c44ded1 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1868,6 +1868,69 @@ class SymbolTest(fixtures.TestBase): assert not (sym1 | sym2) & (sym3 | sym4) assert (sym1 | sym2) & (sym2 | sym4) + def test_parser(self): + sym1 = util.symbol("sym1", canonical=1) + sym2 = util.symbol("sym2", canonical=2) + sym3 = util.symbol("sym3", canonical=4) + sym4 = util.symbol("sym4", canonical=8) + + lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]} + lookup_two = {sym1: [], sym2: [True], sym3: [False]} + lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []} + + is_( + util.symbol.parse_user_argument( + "sym2", lookup_one, "some_name", resolve_symbol_names=True + ), + sym2, + ) + + assert_raises_message( + exc.ArgumentError, + "Invalid value for 'some_name': 'sym2'", + util.symbol.parse_user_argument, + "sym2", + lookup_one, + "some_name", + ) + is_( + util.symbol.parse_user_argument( + True, lookup_one, "some_name", resolve_symbol_names=False + ), + sym2, + ) + + is_( + util.symbol.parse_user_argument(sym2, lookup_one, "some_name"), + sym2, + ) + + is_( + util.symbol.parse_user_argument(None, lookup_one, "some_name"), + sym4, + ) + + is_( + util.symbol.parse_user_argument(None, lookup_two, "some_name"), + None, + ) + + is_( + util.symbol.parse_user_argument( + "symbol2", lookup_three, "some_name" + ), + sym2, + ) + + assert_raises_message( + exc.ArgumentError, + "Invalid value for 'some_name': 'foo'", + util.symbol.parse_user_argument, + "foo", + lookup_three, + "some_name", + ) + class _Py3KFixtures(object): pass