]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pass executemany context to _repr_params
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Oct 2019 17:55:19 +0000 (13:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Oct 2019 15:11:53 +0000 (11:11 -0400)
Fixed bug where parameter repr as used in logging and error reporting needs
additional context in order to distinguish between a list of parameters for
a single statement and a list of parameter lists, as the "list of lists"
structure could also indicate a single parameter list where the first
parameter itself is a list, such as for an array parameter.   The
engine/connection now passes in an additional boolean indicating how the
parameters should be considered.  The only SQLAlchemy backend that expects
arrays as parameters is that of  psycopg2 which uses pyformat parameters,
so this issue has not been too apparent, however as other drivers that use
positional gain more features it is important that this be supported. It
also eliminates the need for the parameter repr function to guess based on
the parameter structure passed.

Fixes: #4902
Change-Id: I086246ee0eb51484adbefd83e07295fa56576c5f
(cherry picked from commit 9488480abea15298ded6996aa13b42edf134e467)

doc/build/changelog/unreleased_13/4902.rst [new file with mode: 0644]
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
test/base/test_except.py
test/engine/test_logging.py

diff --git a/doc/build/changelog/unreleased_13/4902.rst b/doc/build/changelog/unreleased_13/4902.rst
new file mode 100644 (file)
index 0000000..673077a
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 4902
+
+    Fixed bug where parameter repr as used in logging and error reporting needs
+    additional context in order to distinguish between a list of parameters for
+    a single statement and a list of parameter lists, as the "list of lists"
+    structure could also indicate a single parameter list where the first
+    parameter itself is a list, such as for an array parameter.   The
+    engine/connection now passes in an additional boolean indicating how the
+    parameters should be considered.  The only SQLAlchemy backend that expects
+    arrays as parameters is that of  psycopg2 which uses pyformat parameters,
+    so this issue has not been too apparent, however as other drivers that use
+    positional gain more features it is important that this be supported. It
+    also eliminates the need for the parameter repr function to guess based on
+    the parameter structure passed.
index fd9b11edbc09d0a7411c0c782b48f1a8458d2be1..a3d9a22b387ff4e80e0b63a1f92d5f6831939d59 100644 (file)
@@ -1209,7 +1209,10 @@ class Connection(Connectable):
             self.engine.logger.info(statement)
             if not self.engine.hide_parameters:
                 self.engine.logger.info(
-                    "%r", sql_util._repr_params(parameters, batches=10)
+                    "%r",
+                    sql_util._repr_params(
+                        parameters, batches=10, ismulti=context.executemany
+                    ),
                 )
             else:
                 self.engine.logger.info(
@@ -1380,6 +1383,9 @@ class Connection(Connectable):
                     self.dialect.dbapi.Error,
                     hide_parameters=self.engine.hide_parameters,
                     dialect=self.dialect,
+                    ismulti=context.executemany
+                    if context is not None
+                    else None,
                 ),
                 exc_info,
             )
@@ -1402,6 +1408,9 @@ class Connection(Connectable):
                     hide_parameters=self.engine.hide_parameters,
                     connection_invalidated=self._is_disconnect,
                     dialect=self.dialect,
+                    ismulti=context.executemany
+                    if context is not None
+                    else None,
                 )
             else:
                 sqlalchemy_exception = None
index efee58d991643d9cdf2356c7b31debfe0d441c81..58e9fb78d4a757874b96879113c6c9f862b9ec7c 100644 (file)
@@ -332,6 +332,8 @@ class StatementError(SQLAlchemyError):
     orig = None
     """The DBAPI exception object."""
 
+    ismulti = None
+
     def __init__(
         self,
         message,
@@ -340,11 +342,13 @@ class StatementError(SQLAlchemyError):
         orig,
         hide_parameters=False,
         code=None,
+        ismulti=None,
     ):
         SQLAlchemyError.__init__(self, message, code=code)
         self.statement = statement
         self.params = params
         self.orig = orig
+        self.ismulti = ismulti
         self.hide_parameters = hide_parameters
         self.detail = []
 
@@ -360,6 +364,7 @@ class StatementError(SQLAlchemyError):
                 self.params,
                 self.orig,
                 self.hide_parameters,
+                self.ismulti,
             ),
         )
 
@@ -381,7 +386,9 @@ class StatementError(SQLAlchemyError):
                         "[SQL parameters hidden due to hide_parameters=True]"
                     )
                 else:
-                    params_repr = util._repr_params(self.params, 10)
+                    params_repr = util._repr_params(
+                        self.params, 10, ismulti=self.ismulti
+                    )
                     details.append("[parameters: %r]" % params_repr)
         code_str = self._code_str()
         if code_str:
@@ -424,6 +431,7 @@ class DBAPIError(StatementError):
         hide_parameters=False,
         connection_invalidated=False,
         dialect=None,
+        ismulti=None,
     ):
         # Don't ever wrap these, just return them directly as if
         # DBAPIError didn't exist.
@@ -448,6 +456,7 @@ class DBAPIError(StatementError):
                     orig,
                     hide_parameters=hide_parameters,
                     code=orig.code,
+                    ismulti=ismulti,
                 )
             elif not isinstance(orig, dbapi_base_err) and statement:
                 return StatementError(
@@ -461,6 +470,7 @@ class DBAPIError(StatementError):
                     params,
                     orig,
                     hide_parameters=hide_parameters,
+                    ismulti=ismulti,
                 )
 
             glob = globals()
@@ -481,6 +491,7 @@ class DBAPIError(StatementError):
             connection_invalidated=connection_invalidated,
             hide_parameters=hide_parameters,
             code=cls.code,
+            ismulti=ismulti,
         )
 
     def __reduce__(self):
@@ -492,6 +503,7 @@ class DBAPIError(StatementError):
                 self.orig,
                 self.hide_parameters,
                 self.connection_invalidated,
+                self.ismulti,
             ),
         )
 
@@ -503,6 +515,7 @@ class DBAPIError(StatementError):
         hide_parameters=False,
         connection_invalidated=False,
         code=None,
+        ismulti=None,
     ):
         try:
             text = str(orig)
@@ -517,6 +530,7 @@ class DBAPIError(StatementError):
             orig,
             hide_parameters,
             code=code,
+            ismulti=ismulti,
         )
         self.connection_invalidated = connection_invalidated
 
index d90b3f158bf562663a926e79bfa52de42130a420..75c00218d134cd9d3019f5b0fa2adda48247025f 100644 (file)
@@ -468,31 +468,29 @@ class _repr_params(_repr_base):
 
     """
 
-    __slots__ = "params", "batches"
+    __slots__ = "params", "batches", "ismulti"
 
-    def __init__(self, params, batches, max_chars=300):
+    def __init__(self, params, batches, max_chars=300, ismulti=None):
         self.params = params
+        self.ismulti = ismulti
         self.batches = batches
         self.max_chars = max_chars
 
     def __repr__(self):
+        if self.ismulti is None:
+            return self.trunc(self.params)
+
         if isinstance(self.params, list):
             typ = self._LIST
-            ismulti = self.params and isinstance(
-                self.params[0], (list, dict, tuple)
-            )
+
         elif isinstance(self.params, tuple):
             typ = self._TUPLE
-            ismulti = self.params and isinstance(
-                self.params[0], (list, dict, tuple)
-            )
         elif isinstance(self.params, dict):
             typ = self._DICT
-            ismulti = False
         else:
             return self.trunc(self.params)
 
-        if ismulti and len(self.params) > self.batches:
+        if self.ismulti and len(self.params) > self.batches:
             msg = " ... displaying %i of %i total bound parameter sets ... "
             return " ".join(
                 (
@@ -503,7 +501,7 @@ class _repr_params(_repr_base):
                     self._repr_multi(self.params[-2:], typ)[1:],
                 )
             )
-        elif ismulti:
+        elif self.ismulti:
             return self._repr_multi(self.params, typ)
         else:
             return self._repr_params(self.params, typ)
index c52e9a76f1a1a246ed424f6fe9acb5db9ef21672..ecfef9471b69461d7e56c738f8a581ab70cad92e 100644 (file)
@@ -10,6 +10,7 @@ from . import config  # noqa
 from . import mock  # noqa
 from .assertions import assert_raises  # noqa
 from .assertions import assert_raises_message  # noqa
+from .assertions import assert_raises_return  # noqa
 from .assertions import AssertsCompiledSQL  # noqa
 from .assertions import AssertsExecutionResults  # noqa
 from .assertions import ComparesTables  # noqa
index d8038e225cf5cd139400dd3ef3aca6fb500fada9..0bd83de4a40addf63b3b24db2a59cdf2f6a4e4ab 100644 (file)
@@ -303,6 +303,20 @@ def assert_raises(except_cls, callable_, *args, **kw):
     assert success, "Callable did not raise an exception"
 
 
+def assert_raises_return(except_cls, callable_, *args, **kw):
+    ret_err = None
+    try:
+        callable_(*args, **kw)
+        success = False
+    except except_cls as err:
+        success = True
+        ret_err = err
+
+    # assert outside the block so it works for AssertionError too !
+    assert success, "Callable did not raise an exception"
+    return ret_err
+
+
 def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
     try:
         callable_(*args, **kwargs)
index 7dcfbb1f0923ab97d77fc93d06a6da6c7bf72a22..be8c85a6439b3668df2aef7a2b9a7cb0bf99abf3 100644 (file)
@@ -241,6 +241,7 @@ class WrapTest(fixtures.TestBase):
                 ],
                 OperationalError(),
                 DatabaseError,
+                ismulti=True,
             )
         except sa_exceptions.DBAPIError as exc:
             eq_(
@@ -288,6 +289,7 @@ class WrapTest(fixtures.TestBase):
                 ],
                 OperationalError(),
                 DatabaseError,
+                ismulti=True,
             )
         except sa_exceptions.DBAPIError as exc:
             eq_(
index d55f4249a0f34d061bf3222af4fb93692df6d308..14b49e5e89af0b5292b815e6c09ef5c6db82dc30 100644 (file)
@@ -9,7 +9,9 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import util
+from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_return
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import eq_regex
@@ -38,7 +40,7 @@ class LogParamsTest(fixtures.TestBase):
         for log in [logging.getLogger("sqlalchemy.engine")]:
             log.removeHandler(self.buf)
 
-    def test_log_large_dict(self):
+    def test_log_large_list_of_dict(self):
         self.eng.execute(
             "INSERT INTO foo (data) values (:data)",
             [{"data": str(i)} for i in range(100)],
@@ -51,6 +53,21 @@ class LogParamsTest(fixtures.TestBase):
             "parameter sets ...  {'data': '98'}, {'data': '99'}]",
         )
 
+    def test_repr_params_large_list_of_dict(self):
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [{"data": str(i)} for i in range(100)],
+                    batches=10,
+                    ismulti=True,
+                )
+            ),
+            "[{'data': '0'}, {'data': '1'}, {'data': '2'}, {'data': '3'}, "
+            "{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}"
+            "  ... displaying 10 of 100 total bound "
+            "parameter sets ...  {'data': '98'}, {'data': '99'}]",
+        )
+
     def test_log_no_parameters(self):
         self.no_param_engine.execute(
             "INSERT INTO foo (data) values (:data)",
@@ -61,7 +78,7 @@ class LogParamsTest(fixtures.TestBase):
             "[SQL parameters hidden due to hide_parameters=True]",
         )
 
-    def test_log_large_list(self):
+    def test_log_large_list_of_tuple(self):
         self.eng.execute(
             "INSERT INTO foo (data) values (?)",
             [(str(i),) for i in range(100)],
@@ -73,6 +90,119 @@ class LogParamsTest(fixtures.TestBase):
             "bound parameter sets ...  ('98',), ('99',)]",
         )
 
+    def test_log_positional_array(self):
+        with self.eng.connect() as conn:
+            exc_info = assert_raises_return(
+                tsa.exc.DBAPIError,
+                conn.execute,
+                tsa.text("SELECT * FROM foo WHERE id IN :foo AND bar=:bar"),
+                {"foo": [1, 2, 3], "bar": "hi"},
+            )
+
+            assert (
+                "[SQL: SELECT * FROM foo WHERE id IN ? AND bar=?]\n"
+                "[parameters: ([1, 2, 3], 'hi')]\n" in str(exc_info)
+            )
+
+            eq_(self.buf.buffer[1].message, "([1, 2, 3], 'hi')")
+
+    def test_repr_params_positional_array(self):
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[1, 2, 3], 5], batches=10, ismulti=False
+                )
+            ),
+            "[[1, 2, 3], 5]",
+        )
+
+    def test_repr_params_unknown_list(self):
+        # not known if given multiparams or not.   repr params with
+        # straight truncation
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[i for i in range(300)], 5], batches=10, max_chars=80
+                )
+            ),
+            "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,  ... "
+            "(1315 characters truncated) ... , 293, 294, 295, 296, "
+            "297, 298, 299], 5]",
+        )
+
+    def test_repr_params_positional_list(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [[i for i in range(300)], 5],
+                    batches=10,
+                    max_chars=80,
+                    ismulti=False,
+                )
+            ),
+            "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  "
+            "292, 293, 294, 295, 296, 297, 298, 299], 5]",
+        )
+
+    def test_repr_params_named_dict(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        params = {"key_%s" % i: i for i in range(10)}
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    params, batches=10, max_chars=80, ismulti=False
+                )
+            ),
+            repr(params),
+        )
+
+    def test_repr_params_ismulti_named_dict(self):
+        # given non-multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        param = {"key_%s" % i: i for i in range(10)}
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [param for j in range(50)],
+                    batches=5,
+                    max_chars=80,
+                    ismulti=True,
+                )
+            ),
+            "[%(param)r, %(param)r, %(param)r  ... "
+            "displaying 5 of 50 total bound parameter sets ...  "
+            "%(param)r, %(param)r]" % {"param": param},
+        )
+
+    def test_repr_params_ismulti_list(self):
+        # given multi-params in a list.   repr params with
+        # per-element truncation, mostly does the exact same thing
+        eq_(
+            repr(
+                sql_util._repr_params(
+                    [
+                        [[i for i in range(300)], 5],
+                        [[i for i in range(300)], 5],
+                        [[i for i in range(300)], 5],
+                    ],
+                    batches=10,
+                    max_chars=80,
+                    ismulti=True,
+                )
+            ),
+            "[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5], [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5], [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1 ... "
+            "(1310 characters truncated) ...  292, 293, 294, 295, 296, 297, "
+            "298, 299], 5]]",
+        )
+
     def test_log_large_parameter_single(self):
         import random