]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pass executemany context to _repr_params
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Oct 2019 17:55:19 +0000 (13:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Oct 2019 15:11:42 +0000 (11:11 -0400)
Fixed bug where parameter repr as used in logging and error reporting needs
additional context in order to distinguish between a list of parameters for
a single statement and a list of parameter lists, as the "list of lists"
structure could also indicate a single parameter list where the first
parameter itself is a list, such as for an array parameter.   The
engine/connection now passes in an additional boolean indicating how the
parameters should be considered.  The only SQLAlchemy backend that expects
arrays as parameters is that of  psycopg2 which uses pyformat parameters,
so this issue has not been too apparent, however as other drivers that use
positional gain more features it is important that this be supported. It
also eliminates the need for the parameter repr function to guess based on
the parameter structure passed.

Fixes: #4902
Change-Id: I086246ee0eb51484adbefd83e07295fa56576c5f

doc/build/changelog/unreleased_13/4902.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
test/base/test_except.py
test/engine/test_logging.py

diff --git a/doc/build/changelog/unreleased_13/4902.rst b/doc/build/changelog/unreleased_13/4902.rst
new file mode 100644 (file)
index 0000000..673077a
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 4902
+
+    Fixed bug where parameter repr as used in logging and error reporting needs
+    additional context in order to distinguish between a list of parameters for
+    a single statement and a list of parameter lists, as the "list of lists"
+    structure could also indicate a single parameter list where the first
+    parameter itself is a list, such as for an array parameter.   The
+    engine/connection now passes in an additional boolean indicating how the
+    parameters should be considered.  The only SQLAlchemy backend that expects
+    arrays as parameters is that of  psycopg2 which uses pyformat parameters,
+    so this issue has not been too apparent, however as other drivers that use
+    positional gain more features it is important that this be supported. It
+    also eliminates the need for the parameter repr function to guess based on
+    the parameter structure passed.
index 27da8c295c4e13b55c764027b44342d450a6b980..95e05be9834f8f710cc5dc30252fbe2a472ce547 100644 (file)
@@ -1223,7 +1223,10 @@ class Connection(Connectable):
             self.engine.logger.info(statement)
             if not self.engine.hide_parameters:
                 self.engine.logger.info(
-                    "%r", sql_util._repr_params(parameters, batches=10)
+                    "%r",
+                    sql_util._repr_params(
+                        parameters, batches=10, ismulti=context.executemany
+                    ),
                 )
             else:
                 self.engine.logger.info(
@@ -1394,6 +1397,9 @@ class Connection(Connectable):
                     self.dialect.dbapi.Error,
                     hide_parameters=self.engine.hide_parameters,
                     dialect=self.dialect,
+                    ismulti=context.executemany
+                    if context is not None
+                    else None,
                 ),
                 exc_info,
             )
@@ -1416,6 +1422,9 @@ class Connection(Connectable):
                     hide_parameters=self.engine.hide_parameters,
                     connection_invalidated=self._is_disconnect,
                     dialect=self.dialect,
+                    ismulti=context.executemany
+                    if context is not None
+                    else None,
                 )
             else:
                 sqlalchemy_exception = None
index 1b3ac7ce2cb9b3dd0d8736694b41d1762bb8571f..79f7868821265d8cbaa4513db0fea6c481a6fc63 100644 (file)
@@ -332,6 +332,8 @@ class StatementError(SQLAlchemyError):
     orig = None
     """The DBAPI exception object."""
 
+    ismulti = None
+
     def __init__(
         self,
         message,
@@ -340,11 +342,13 @@ class StatementError(SQLAlchemyError):
         orig,
         hide_parameters=False,
         code=None,
+        ismulti=None,
     ):
         SQLAlchemyError.__init__(self, message, code=code)
         self.statement = statement
         self.params = params
         self.orig = orig
+        self.ismulti = ismulti
         self.hide_parameters = hide_parameters
         self.detail = []
 
@@ -360,6 +364,7 @@ class StatementError(SQLAlchemyError):
                 self.params,
                 self.orig,
                 self.hide_parameters,
+                self.ismulti,
             ),
         )
 
@@ -381,7 +386,9 @@ class StatementError(SQLAlchemyError):
                         "[SQL parameters hidden due to hide_parameters=True]"
                     )
                 else:
-                    params_repr = util._repr_params(self.params, 10)
+                    params_repr = util._repr_params(
+                        self.params, 10, ismulti=self.ismulti
+                    )
                     details.append("[parameters: %r]" % params_repr)
         code_str = self._code_str()
         if code_str:
@@ -424,6 +431,7 @@ class DBAPIError(StatementError):
         hide_parameters=False,
         connection_invalidated=False,
         dialect=None,
+        ismulti=None,
     ):
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
@@ -448,6 +456,7 @@ class DBAPIError(StatementError):
                     orig,
                     hide_parameters=hide_parameters,
                     code=orig.code,
+                    ismulti=ismulti,
                 )
             elif not isinstance(orig, dbapi_base_err) and statement:
                 return StatementError(
@@ -461,6 +470,7 @@ class DBAPIError(StatementError):
                     params,
                     orig,
                     hide_parameters=hide_parameters,
+                    ismulti=ismulti,
                 )
 
             glob = globals()
@@ -481,6 +491,7 @@ class DBAPIError(StatementError):
             connection_invalidated=connection_invalidated,
             hide_parameters=hide_parameters,
             code=cls.code,
+            ismulti=ismulti,
         )
 
     def __reduce__(self):
@@ -492,6 +503,7 @@ class DBAPIError(StatementError):
                 self.orig,
                 self.hide_parameters,
                 self.connection_invalidated,
+                self.ismulti,
             ),
         )
 
@@ -503,6 +515,7 @@ class DBAPIError(StatementError):
         hide_parameters=False,
         connection_invalidated=False,
         code=None,
+        ismulti=None,
     ):
         try:
             text = str(orig)
@@ -517,6 +530,7 @@ class DBAPIError(StatementError):
             orig,
             hide_parameters,
             code=code,
+            ismulti=ismulti,
         )
         self.connection_invalidated = connection_invalidated
 
index 5aeed0c1c0cedc65c24927a4dd3affe8cccbed54..e109852a2b09e6628e1a319c3b1eaea1d75a5c04 100644 (file)
@@ -464,31 +464,29 @@ class _repr_params(_repr_base):
 
     """
 
-    __slots__ = "params", "batches"
+    __slots__ = "params", "batches", "ismulti"
 
-    def __init__(self, params, batches, max_chars=300):
+    def __init__(self, params, batches, max_chars=300, ismulti=None):
         self.params = params
+        self.ismulti = ismulti
         self.batches = batches
         self.max_chars = max_chars
 
     def __repr__(self):
+        if self.ismulti is None:
+            return self.trunc(self.params)
+
         if isinstance(self.params, list):
             typ = self._LIST
-            ismulti = self.params and isinstance(
-                self.params[0], (list, dict, tuple)
-            )
+
         elif isinstance(self.params, tuple):
             typ = self._TUPLE
-            ismulti = self.params and isinstance(
-                self.params[0], (list, dict, tuple)
-            )
         elif isinstance(self.params, dict):
             typ = self._DICT
-            ismulti = False
         else:
             return self.trunc(self.params)
 
-        if ismulti and len(self.params) > self.batches:
+        if self.ismulti and len(self.params) > self.batches:
             msg = " ... displaying %i of %i total bound parameter sets ... "
             return " ".join(
                 (
@@ -499,7 +497,7 @@ class _repr_params(_repr_base):
                     self._repr_multi(self.params[-2:], typ)[1:],
                 )
             )
-        elif ismulti:
+        elif self.ismulti:
             return self._repr_multi(self.params, typ)
         else:
             return self._repr_params(self.params, typ)
index 090c8488a5910554948ff5c9285e814861bbc2ef..2b8158fbb1bd84ad0870ef18b20d1fc5d2e574a0 100644 (file)
@@ -10,6 +10,7 @@ from . import config  # noqa
 from . import mock  # noqa
 from .assertions import assert_raises  # noqa
 from .assertions import assert_raises_message  # noqa
+from .assertions import assert_raises_return  # noqa
 from .assertions import AssertsCompiledSQL  # noqa
 from .assertions import AssertsExecutionResults  # noqa
 from .assertions import ComparesTables  # noqa
index 819fedcc772b016734085ecbcb5f3abe5f98f8fa..f057ae37b283a9d7debef5b27f4fa5ad35eff1de 100644 (file)
@@ -307,6 +307,20 @@ def assert_raises(except_cls, callable_, *args, **kw):
     assert success, "Callable did not raise an exception"
 
 
+def assert_raises_return(except_cls, callable_, *args, **kw):
+    ret_err = None
+    try:
+        callable_(*args, **kw)
+        success = False
+    except except_cls as err:
+        success = True
+        ret_err = err
+
+    # assert outside the block so it works for AssertionError too !
+    assert success, "Callable did not raise an exception"
+    return ret_err
+
+
 def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
     try:
         callable_(*args, **kwargs)
index 7dcfbb1f0923ab97d77fc93d06a6da6c7bf72a22..be8c85a6439b3668df2aef7a2b9a7cb0bf99abf3 100644 (file)
@@ -241,6 +241,7 @@ class WrapTest(fixtures.TestBase):
                 ],
                 OperationalError(),
                 DatabaseError,
+                ismulti=True,
             )
         except sa_exceptions.DBAPIError as exc:
             eq_(
@@ -288,6 +289,7 @@ class WrapTest(fixtures.TestBase):
                 ],
                 OperationalError(),
                 DatabaseError,
+                ismulti=True,
             )
         except sa_exceptions.DBAPIError as exc:
             eq_(
index d55f4249a0f34d061bf3222af4fb93692df6d308..14b49e5e89af0b5292b815e6c09ef5c6db82dc30 100644 (file)
@@ -9,7 +9,9 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import util
+from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_return
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import eq_regex
@@ -38,7 +40,7 @@ class LogParamsTest(fixtures.TestBase):
         for log in [logging.getLogger("sqlalchemy.engine")]:
             log.removeHandler(self.buf)
 
-    def test_log_large_dict(self):
+    def test_log_large_list_of_dict(self):
         self.eng.execute(
             "INSERT INTO foo (data) values (:data)",
             [{"data": str(i)} for i in range(100)],
@@ -51,6 +53,21 @@ class LogParamsTest(fixtures.TestBase):
             "parameter sets ...  {'data': '98'}, {'data': '99'}]",
         )
 
+    def test_repr_params_large_list_of_dict(self):
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [{"data": str(i)} for i in range(100)],
+                    batches=10,
+                    ismulti=True,
+                )
+            ),
+            "[{'data': '0'}, {'data': '1'}, {'data': '2'}, {'data': '3'}, "
+            "{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}"
+            "  ... displaying 10 of 100 total bound "
+            "parameter sets ...  {'data': '98'}, {'data': '99'}]",
+        )
+
     def test_log_no_parameters(self):
         self.no_param_engine.execute(
             "INSERT INTO foo (data) values (:data)",
@@ -61,7 +78,7 @@ class LogParamsTest(fixtures.TestBase):
             "[SQL parameters hidden due to hide_parameters=True]",
         )
 
-    def test_log_large_list(self):
+    def test_log_large_list_of_tuple(self):
         self.eng.execute(
             "INSERT INTO foo (data) values (?)",
             [(str(i),) for i in range(100)],
@@ -73,6 +90,119 @@ class LogParamsTest(fixtures.TestBase):
             "bound parameter sets ...  ('98',), ('99',)]",
         )
 
+    def test_log_positional_array(self):
+        with self.eng.connect() as conn:
+            exc_info = assert_raises_return(
+                tsa.exc.DBAPIError,
+                conn.execute,
+                tsa.text("SELECT * FROM foo WHERE id IN :foo AND bar=:bar"),
+                {"foo": [1, 2, 3], "bar": "hi"},
+            )
+
+            assert (
+                "[SQL: SELECT * FROM foo WHERE id IN ? AND bar=?]\n"
+                "[parameters: ([1, 2, 3], 'hi')]\n" in str(exc_info)
+            )
+
+            eq_(self.buf.buffer[1].message, "([1, 2, 3], 'hi')")
+
+    def test_repr_params_positional_array(self):
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[1, 2, 3], 5], batches=10, ismulti=False
+                )
+            ),
+            "[[1, 2, 3], 5]",
+        )
+
+    def test_repr_params_unknown_list(self):
+        # not known if given multiparams or not.   repr params with
+        # straight truncation
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[i for i in range(300)], 5], batches=10, max_chars=80
+                )
+            ),
+            "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  ... "
+            "(1315 characters truncated) ... , 293, 294, 295, 296, "
+            "297, 298, 299], 5]",
+        )
+
+    def test_repr_params_positional_list(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[i for i in range(300)], 5],
+                    batches=10,
+                    max_chars=80,
+                    ismulti=False,
+                )
+            ),
+            "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  "
+            "292, 293, 294, 295, 296, 297, 298, 299], 5]",
+        )
+
+    def test_repr_params_named_dict(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        params = {"key_%s" % i: i for i in range(10)}
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    params, batches=10, max_chars=80, ismulti=False
+                )
+            ),
+            repr(params),
+        )
+
+    def test_repr_params_ismulti_named_dict(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        param = {"key_%s" % i: i for i in range(10)}
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [param for j in range(50)],
+                    batches=5,
+                    max_chars=80,
+                    ismulti=True,
+                )
+            ),
+            "[%(param)r, %(param)r, %(param)r  ... "
+            "displaying 5 of 50 total bound parameter sets ...  "
+            "%(param)r, %(param)r]" % {"param": param},
+        )
+
+    def test_repr_params_ismulti_list(self):
+        # given multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [
+                        [[i for i in range(300)], 5],
+                        [[i for i in range(300)], 5],
+                        [[i for i in range(300)], 5],
+                    ],
+                    batches=10,
+                    max_chars=80,
+                    ismulti=True,
+                )
+            ),
+            "[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5], [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5], [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5]]",
+        )
+
     def test_log_large_parameter_single(self):
         import random