]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure escaping of percent signs in columns, parameters
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Oct 2020 15:39:56 +0000 (11:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Oct 2020 17:15:57 +0000 (13:15 -0400)
Improved support for column names that contain percent signs in the string,
including repaired issues involving anoymous labels that also embedded a
column name with a percent sign in it, as well as re-established support
for bound parameter names with percent signs embedded on the psycopg2
dialect, using a late-escaping process similar to that used by the
cx_Oracle dialect.

* Added new constructor for _anonymous_label() that ensures incoming
  string tokens based on column or table names will have percent
  signs escaped; abstracts away the format of the label.

* generalized cx_Oracle's quoted_bind_names facility into the compiler
  itself, and leveraged this for the psycopg2 dialect's issue with
  percent signs in names as well.  the parameter substitution is now
  integrated with compiler.construct_parameters() as well as the
  recently reworked set_input_sizes(), reducing verbosity in the
  cx_Oracle dialect.

Fixes: #5653
Change-Id: Ia2ad13ea68b4b0558d410026e5a33f5cb3fbab2c

doc/build/changelog/unreleased_14/5653.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/requirements.py
test/sql/test_compiler.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/5653.rst b/doc/build/changelog/unreleased_14/5653.rst
new file mode 100644 (file)
index 0000000..0722843
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql, postgresql
+    :tickets: 5653
+
+    Improved support for column names that contain percent signs in the string,
+    including repaired issues involving anoymous labels that also embedded a
+    column name with a percent sign in it, as well as re-established support
+    for bound parameter names with percent signs embedded on the psycopg2
+    dialect, using a late-escaping process similar to that used by the
+    cx_Oracle dialect.
+
index 655614769be9b45da922e38aae1d815e4554d1f0..1f1f7501bd378def0aa9d24d9994d1040637c410 100644 (file)
@@ -849,7 +849,6 @@ class OracleCompiler(compiler.SQLCompiler):
 
     def __init__(self, *args, **kwargs):
         self.__wheres = {}
-        self._quoted_bind_names = {}
         super(OracleCompiler, self).__init__(*args, **kwargs)
 
     def visit_mod_binary(self, binary, operator, **kw):
index 7bde19090d18209ffb98d2718b5026d25b77ca9f..8eb9f8b3cfaa86991285bc3e76f3082ec7243919 100644 (file)
@@ -598,28 +598,20 @@ class OracleCompiler_cx_oracle(OracleCompiler):
             # need quoting :).    names that include illegal characters
             # won't work however.
             quoted_name = '"%s"' % name
-            self._quoted_bind_names[name] = quoted_name
-            return OracleCompiler.bindparam_string(self, quoted_name, **kw)
-        else:
-            return OracleCompiler.bindparam_string(self, name, **kw)
+            kw["escaped_from"] = name
+            name = quoted_name
+
+        return OracleCompiler.bindparam_string(self, name, **kw)
 
 
 class OracleExecutionContext_cx_oracle(OracleExecutionContext):
     out_parameters = None
 
-    def _setup_quoted_bind_names(self):
-        quoted_bind_names = self.compiled._quoted_bind_names
-        if quoted_bind_names:
-            for param in self.parameters:
-                for fromname, toname in quoted_bind_names.items():
-                    param[toname] = param[fromname]
-                    del param[fromname]
-
     def _generate_out_parameter_vars(self):
         # check for has_out_parameters or RETURNING, create cx_Oracle.var
         # objects if so
         if self.compiled.returning or self.compiled.has_out_parameters:
-            quoted_bind_names = self.compiled._quoted_bind_names
+            quoted_bind_names = self.compiled.escaped_bind_names
             for bindparam in self.compiled.binds.values():
                 if bindparam.isoutparam:
                     name = self.compiled.bind_names[bindparam]
@@ -684,9 +676,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
 
         self.out_parameters = {}
 
-        if self.compiled._quoted_bind_names:
-            self._setup_quoted_bind_names()
-
         self._generate_out_parameter_vars()
 
         self._generate_cursor_outputtype_handler()
@@ -1184,12 +1173,6 @@ class OracleDialect_cx_oracle(OracleDialect):
                 for key, dbtype, sqltype in list_of_tuples
                 if dbtype
             )
-            if context and context.compiled:
-                quoted_bind_names = context.compiled._quoted_bind_names
-                collection = (
-                    (quoted_bind_names.get(key, key), dbtype)
-                    for key, dbtype in collection
-                )
 
             if not self.supports_unicode_binds:
                 # oracle 8 only
index 2446604ba552984766dd0a71454baf5ad42a11c0..72c36b4a86917e96b53a8fcd1dee5c8d511c8c46 100644 (file)
@@ -644,7 +644,14 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
 
 
 class PGCompiler_psycopg2(PGCompiler):
-    pass
+    def bindparam_string(self, name, **kw):
+        if "%" in name and not kw.get("post_compile", False):
+            # psycopg2 will not allow a percent sign in a
+            # pyformat parameter name even if it is doubled
+            kw["escaped_from"] = name
+            name = name.replace("%", "P")
+
+        return PGCompiler.bindparam_string(self, name, **kw)
 
 
 class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
index d63cb4addd838b2752b513c17a30bd197fd7f802..7f92271e97c0bd2ffb1af3cb44a184df48a558b5 100644 (file)
@@ -1507,6 +1507,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 inputsizes, self.cursor, self.statement, self.parameters, self
             )
 
+        has_escaped_names = bool(self.compiled.escaped_bind_names)
+        if has_escaped_names:
+            escaped_bind_names = self.compiled.escaped_bind_names
+
         if self.dialect.positional:
             items = [
                 (key, self.compiled.binds[key])
@@ -1529,7 +1533,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                     dbtypes = inputsizes[bindparam]
                     generic_inputsizes.extend(
                         (
-                            paramname,
+                            (
+                                escaped_bind_names.get(paramname, paramname)
+                                if has_escaped_names
+                                else paramname
+                            ),
                             dbtypes[idx % num],
                             bindparam.type.types[idx % num],
                         )
@@ -1540,12 +1548,29 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 else:
                     dbtype = inputsizes.get(bindparam, None)
                     generic_inputsizes.extend(
-                        (paramname, dbtype, bindparam.type)
+                        (
+                            (
+                                escaped_bind_names.get(paramname, paramname)
+                                if has_escaped_names
+                                else paramname
+                            ),
+                            dbtype,
+                            bindparam.type,
+                        )
                         for paramname in self._expanded_parameters[key]
                     )
             else:
                 dbtype = inputsizes.get(bindparam, None)
-                generic_inputsizes.append((key, dbtype, bindparam.type))
+
+                escaped_name = (
+                    escaped_bind_names.get(key, key)
+                    if has_escaped_names
+                    else key
+                )
+
+                generic_inputsizes.append(
+                    (escaped_name, dbtype, bindparam.type)
+                )
         try:
             self.dialect.do_set_input_sizes(
                 self.cursor, generic_inputsizes, self
index 23cd778d04567e74ccb658e914e02b3ab45efd39..8718e15ea9b64c1f7c742dcc08648df0696ccd15 100644 (file)
@@ -663,6 +663,12 @@ class SQLCompiler(Compiled):
 
     """
 
+    escaped_bind_names = util.EMPTY_DICT
+    """Late escaping of bound parameter names that has to be converted
+    to the original name when looking in the parameter dictionary.
+
+    """
+
     has_out_parameters = False
     """if True, there are bindparam() objects that have the isoutparam
     flag set."""
@@ -879,6 +885,8 @@ class SQLCompiler(Compiled):
     ):
         """return a dictionary of bind parameter keys and values"""
 
+        has_escaped_names = bool(self.escaped_bind_names)
+
         if extracted_parameters:
             # related the bound parameters collected in the original cache key
             # to those collected in the incoming cache key.  They will not have
@@ -908,10 +916,16 @@ class SQLCompiler(Compiled):
         if params:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if bindparam.key in params:
-                    pd[name] = params[bindparam.key]
+                    pd[escaped_name] = params[bindparam.key]
                 elif name in params:
-                    pd[name] = params[name]
+                    pd[escaped_name] = params[name]
 
                 elif _check and bindparam.required:
                     if _group_number:
@@ -936,13 +950,19 @@ class SQLCompiler(Compiled):
                         value_param = bindparam
 
                     if bindparam.callable:
-                        pd[name] = value_param.effective_value
+                        pd[escaped_name] = value_param.effective_value
                     else:
-                        pd[name] = value_param.value
+                        pd[escaped_name] = value_param.value
             return pd
         else:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
@@ -964,9 +984,9 @@ class SQLCompiler(Compiled):
                     value_param = bindparam
 
                 if bindparam.callable:
-                    pd[name] = value_param.effective_value
+                    pd[escaped_name] = value_param.effective_value
                 else:
-                    pd[name] = value_param.value
+                    pd[escaped_name] = value_param.value
             return pd
 
     @util.memoized_instancemethod
@@ -2316,6 +2336,7 @@ class SQLCompiler(Compiled):
         positional_names=None,
         post_compile=False,
         expanding=False,
+        escaped_from=None,
         **kw
     ):
         if self.positional:
@@ -2323,6 +2344,11 @@ class SQLCompiler(Compiled):
                 positional_names.append(name)
             else:
                 self.positiontup.append(name)
+
+        if escaped_from:
+            if not self.escaped_bind_names:
+                self.escaped_bind_names = {}
+            self.escaped_bind_names[escaped_from] = name
         if post_compile:
             return "[POSTCOMPILE_%s]" % name
         else:
index 5fb28f1d1a6e07e22fee60385272deefbcf3ec27..00e28ac2072ea8734a7df632be70f59f764e019f 100644 (file)
@@ -957,9 +957,11 @@ class ColumnElement(
         # as the identifier, because a column and its annotated version are
         # the same thing in a SQL statement
         if isinstance(seed, _anonymous_label):
-            return _anonymous_label("%s%%(%d %s)s" % (seed, hash(self), ""))
+            return _anonymous_label.safe_construct(
+                hash(self), "", enclosing_label=seed
+            )
 
-        return _anonymous_label("%%(%d %s)s" % (hash(self), seed or "anon"))
+        return _anonymous_label.safe_construct(hash(self), seed or "anon")
 
     @util.memoized_property
     def anon_label(self):
@@ -1324,21 +1326,17 @@ class BindParameter(roles.InElementRole, ColumnElement):
             key = quoted_name(key, quote)
 
         if unique:
-            self.key = _anonymous_label(
-                "%%(%d %s)s"
-                % (
-                    id(self),
-                    re.sub(r"[%\(\) \$]+", "_", key).strip("_")
-                    if key is not None
-                    and not isinstance(key, _anonymous_label)
-                    else "param",
-                )
+            self.key = _anonymous_label.safe_construct(
+                id(self),
+                re.sub(r"[%\(\) \$]+", "_", key).strip("_")
+                if key is not None and not isinstance(key, _anonymous_label)
+                else "param",
             )
             self._key_is_anon = True
         elif key:
             self.key = key
         else:
-            self.key = _anonymous_label("%%(%d param)s" % id(self))
+            self.key = _anonymous_label.safe_construct(id(self), "param")
             self._key_is_anon = True
 
         # identifying key that won't change across
@@ -1407,8 +1405,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
     def _clone(self, maintain_key=False):
         c = ClauseElement._clone(self)
         if not maintain_key and self.unique:
-            c.key = _anonymous_label(
-                "%%(%d %s)s" % (id(c), c._orig_key or "param")
+            c.key = _anonymous_label.safe_construct(
+                id(c), c._orig_key or "param"
             )
         return c
 
@@ -1442,8 +1440,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
     def _convert_to_unique(self):
         if not self.unique:
             self.unique = True
-            self.key = _anonymous_label(
-                "%%(%d %s)s" % (id(self), self._orig_key or "param")
+            self.key = _anonymous_label.safe_construct(
+                id(self), self._orig_key or "param"
             )
 
     def __getstate__(self):
@@ -1459,8 +1457,8 @@ class BindParameter(roles.InElementRole, ColumnElement):
 
     def __setstate__(self, state):
         if state.get("unique", False):
-            state["key"] = _anonymous_label(
-                "%%(%d %s)s" % (id(self), state.get("_orig_key", "param"))
+            state["key"] = _anonymous_label.safe_construct(
+                id(self), state.get("_orig_key", "param")
             )
         self.__dict__.update(state)
 
@@ -4188,8 +4186,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
             self.name = name
             self._resolve_label = self.name
         else:
-            self.name = _anonymous_label(
-                "%%(%d %s)s" % (id(self), getattr(element, "name", "anon"))
+            self.name = _anonymous_label.safe_construct(
+                id(self), getattr(element, "name", "anon")
             )
 
         self.key = self._label = self._key_label = self.name
@@ -4247,9 +4245,8 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
     def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
         self._element = clone(self._element, **kw)
         if anonymize_labels:
-            self.name = self._resolve_label = _anonymous_label(
-                "%%(%d %s)s"
-                % (id(self), getattr(self.element, "name", "anon"))
+            self.name = self._resolve_label = _anonymous_label.safe_construct(
+                id(self), getattr(self.element, "name", "anon")
             )
             self.key = self._label = self._key_label = self.name
 
@@ -4890,17 +4887,39 @@ class _anonymous_label(_truncated_label):
 
     __slots__ = ()
 
+    @classmethod
+    def safe_construct(cls, seed, body, enclosing_label=None):
+        # type: (int, str, Optional[_anonymous_label]) -> _anonymous_label
+
+        label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
+        if enclosing_label:
+            label = "%s%s" % (enclosing_label, label)
+
+        return _anonymous_label(label)
+
     def __add__(self, other):
+        if "%" in other and not isinstance(other, _anonymous_label):
+            other = util.text_type(other).replace("%", "%%")
+        else:
+            other = util.text_type(other)
+
         return _anonymous_label(
             quoted_name(
-                util.text_type.__add__(self, util.text_type(other)), self.quote
+                util.text_type.__add__(self, other),
+                self.quote,
             )
         )
 
     def __radd__(self, other):
+        if "%" in other and not isinstance(other, _anonymous_label):
+            other = util.text_type(other).replace("%", "%%")
+        else:
+            other = util.text_type(other)
+
         return _anonymous_label(
             quoted_name(
-                util.text_type.__add__(util.text_type(other), self), self.quote
+                util.text_type.__add__(other, self),
+                self.quote,
             )
         )
 
index 0e88a899910b93bc79d2bf56fa74399dd61bd613..fd88324007aea086b866d699208b0c62221e5f3d 100644 (file)
@@ -1432,7 +1432,7 @@ class AliasedReturnsRows(NoInit, FromClause):
                 name = getattr(selectable, "name", None)
                 if isinstance(name, _anonymous_label):
                     name = None
-            name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon"))
+            name = _anonymous_label.safe_construct(id(self), name or "anon")
         self.name = name
 
     def _refresh_for_new_column(self, column):
index 355f8910e39f470553435fa35a0f03005e9d496d..4486973f6127b69b1d1d8b26505a989041fe5eb0 100644 (file)
@@ -1343,26 +1343,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def percent_schema_names(self):
-        return skip_if(
-            [
-                (
-                    "+psycopg2",
-                    None,
-                    None,
-                    "psycopg2 2.4 no longer accepts percent "
-                    "sign in bind placeholders",
-                ),
-                (
-                    "+psycopg2cffi",
-                    None,
-                    None,
-                    "psycopg2cffi does not accept percent signs in "
-                    "bind placeholders",
-                ),
-                ("mysql", None, None, "executemany() doesn't work here"),
-                ("mariadb", None, None, "executemany() doesn't work here"),
-            ]
-        )
+        return exclusions.open()
 
     @property
     def order_by_label_with_expression(self):
index ffabf9379df5506b753bc1fbcbdce1c609e1a412..ef2f75b2d6caeb076d2c9d2b7500039abe2fbd02 100644 (file)
@@ -941,6 +941,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS z FROM keyed) AS anon_2) AS anon_1",
         )
 
+    @testing.combinations("per cent", "per % cent", "%percent")
+    def test_percent_names_collide_with_anonymizing(self, name):
+        table1 = table("t1", column(name))
+
+        jj = select(table1.c[name]).subquery()
+        jjj = join(table1, jj, table1.c[name] == jj.c[name])
+
+        j2 = jjj.select().apply_labels().subquery("foo")
+
+        self.assert_compile(
+            j2.select(),
+            'SELECT foo."t1_%(name)s", foo."anon_1_%(name)s" FROM '
+            '(SELECT t1."%(name)s" AS "t1_%(name)s", anon_1."%(name)s" '
+            'AS "anon_1_%(name)s" FROM t1 JOIN (SELECT t1."%(name)s" AS '
+            '"%(name)s" FROM t1) AS anon_1 ON t1."%(name)s" = '
+            'anon_1."%(name)s") AS foo' % {"name": name},
+        )
+
     def test_exists(self):
         s = select(table1.c.myid).where(table1.c.myid == 5)
 
index b98fbd3d07d8fbb152a06a6452548b66b37648e0..c75e8886de4d0d585d10112043b4bf742f5a184c 100644 (file)
@@ -740,7 +740,7 @@ class SelectableTest(
         assert u2.corresponding_column(s1.selected_columns.col1) is u2.c.col1
         assert u2.corresponding_column(s2.selected_columns.col1) is u2.c.col1
 
-    def test_foo(self):
+    def test_union_alias_misc(self):
         s1 = select(table1.c.col1, table1.c.col2)
         s2 = select(table1.c.col2, table1.c.col1)