]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
make bind escape lookup extensible
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Dec 2022 17:56:21 +0000 (12:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Dec 2022 18:37:40 +0000 (13:37 -0500)
To accommodate third party dialects with different character escaping
needs regarding bound parameters, the system by which SQLAlchemy "escapes"
(i.e., replaces with another character in its place) special characters in
bound parameter names has been made extensible for third party dialects,
using the :attr:`.SQLCompiler.bindname_escape_characters` dictionary which can
be overridden at the class declaration level on any :class:`.SQLCompiler`
subclass. As part of this change, the dot ``"."`` has also been added as a
default "escaped" character.

Fixes: #8994
Change-Id: I52fbbfa8c64497b123f57327113df3f022bd1419

doc/build/changelog/unreleased_20/8994.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/8994.rst b/doc/build/changelog/unreleased_20/8994.rst
new file mode 100644 (file)
index 0000000..cd2a056
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8994
+
+    To accommodate third party dialects with different character escaping
+    needs regarding bound parameters, the system by which SQLAlchemy "escapes"
+    (i.e., replaces with another character in its place) special characters in
+    bound parameter names has been made extensible for third party dialects,
+    using the :attr:`.SQLCompiler.bindname_escape_characters` dictionary which can
+    be overridden at the class declaration level on any :class:`.SQLCompiler`
+    subclass. As part of this change, the dot ``"."`` has also been added as a
+    default "escaped" character.
+
index 8f80aed656d365b80314e33f2ef11819a1a6f62f..c45aafae6be1c4da0f9d98b5925c57797560fd9a 100644 (file)
@@ -445,15 +445,6 @@ from ...sql._typing import is_sql_compiler
 _CX_ORACLE_MAGIC_LOB_SIZE = 131072
 
 
-_ORACLE_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]\.\/\? ]")
-
-# Oracle bind names can't start with digits or underscores.
-# currently we rely upon Oracle-specific quoting of bind names in most cases.
-# however for expanding params, the escape chars are used.
-# see #8708
-_ORACLE_BIND_TRANSLATE_CHARS = dict(zip("%():[]./? ", "PAZCCCCCCCC"))
-
-
 class _OracleInteger(sqltypes.Integer):
     def get_dbapi_type(self, dbapi):
         # see https://github.com/oracle/python-cx_Oracle/issues/
@@ -694,6 +685,26 @@ class OracleCompiler_cx_oracle(OracleCompiler):
 
     _oracle_returning = False
 
+    # Oracle bind names can't start with digits or underscores.
+    # Currently we rely upon Oracle-specific quoting of bind names in most
+    # cases.  However, for expanding params, the escape chars are used.
+    # See #8708.
+    bindname_escape_characters = util.immutabledict(
+        {
+            "%": "P",
+            "(": "A",
+            ")": "Z",
+            ":": "C",
+            ".": "C",
+            "[": "C",
+            "]": "C",
+            " ": "C",
+            "\\": "C",
+            "/": "C",
+            "?": "C",
+        }
+    )
+
     def bindparam_string(self, name, **kw):
         quote = getattr(name, "quote", None)
         if (
@@ -721,12 +732,12 @@ class OracleCompiler_cx_oracle(OracleCompiler):
         escaped_from = kw.get("escaped_from", None)
         if not escaped_from:
 
-            if _ORACLE_BIND_TRANSLATE_RE.search(name):
+            if self._bind_translate_re.search(name):
                 # not quite the translate use case as we want to
                 # also get a quick boolean if we even found
                 # unusual characters in the name
-                new_name = _ORACLE_BIND_TRANSLATE_RE.sub(
-                    lambda m: _ORACLE_BIND_TRANSLATE_CHARS[m.group(0)],
+                new_name = self._bind_translate_re.sub(
+                    lambda m: self._bind_translate_chars[m.group(0)],
                     name,
                 )
                 if new_name[0].isdigit() or new_name[0] == "_":
index 66a294d1061bf98cc349a3313074e2e91a597697..596ca986f0dba0c594e916a58879e801334edacd 100644 (file)
@@ -37,6 +37,7 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import ClassVar
 from typing import Dict
 from typing import FrozenSet
 from typing import Iterable
@@ -46,6 +47,7 @@ from typing import MutableMapping
 from typing import NamedTuple
 from typing import NoReturn
 from typing import Optional
+from typing import Pattern
 from typing import Sequence
 from typing import Set
 from typing import Tuple
@@ -238,9 +240,6 @@ BIND_TEMPLATES = {
 }
 
 
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
-
 OPERATORS = {
     # binary
     operators.and_: " AND ",
@@ -714,6 +713,14 @@ class Compiled:
 
         self._gen_time = perf_counter()
 
+    def __init_subclass__(cls) -> None:
+        cls._init_compiler_cls()
+        return super().__init_subclass__()
+
+    @classmethod
+    def _init_compiler_cls(cls):
+        pass
+
     def _execute_on_connection(
         self, connection, distilled_params, execution_options
     ):
@@ -866,6 +873,52 @@ class SQLCompiler(Compiled):
 
     extract_map = EXTRACT_MAP
 
+    bindname_escape_characters: ClassVar[
+        Mapping[str, str]
+    ] = util.immutabledict(
+        {
+            "%": "P",
+            "(": "A",
+            ")": "Z",
+            ":": "C",
+            ".": "_",
+            "[": "_",
+            "]": "_",
+            " ": "_",
+        }
+    )
+    """A mapping (e.g. dict or similar) containing a lookup of
+    characters keyed to replacement characters which will be applied to all
+    'bind names' used in SQL statements as a form of 'escaping'; the given
+    characters are replaced entirely with the 'replacement' character when
+    rendered in the SQL statement, and a similar translation is performed
+    on the incoming names used in parameter dictionaries passed to methods
+    like :meth:`_engine.Connection.execute`.
+
+    This allows bound parameter names used in :func:`_sql.bindparam` and
+    other constructs to have any arbitrary characters present without any
+    concern for characters that aren't allowed at all on the target database.
+
+    Third party dialects can establish their own dictionary here to replace the
+    default mapping, which will ensure that the particular characters in the
+    mapping will never appear in a bound parameter name.
+
+    The dictionary is evaluated at **class creation time**, so cannot be
+    modified at runtime; it must be present on the class when the class
+    is first declared.
+
+    Note that for dialects that have additional bound parameter rules such
+    as additional restrictions on leading characters, the
+    :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
+    See the cx_Oracle compiler for an example of this.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    _bind_translate_re: ClassVar[Pattern[str]]
+    _bind_translate_chars: ClassVar[Mapping[str, str]]
+
     is_sql = True
 
     compound_keywords = COMPOUND_KEYWORDS
@@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled):
         f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
     )
 
+    @classmethod
+    def _init_compiler_cls(cls):
+        cls._init_bind_translate()
+
+    @classmethod
+    def _init_bind_translate(cls):
+        reg = re.escape("".join(cls.bindname_escape_characters))
+        cls._bind_translate_re = re.compile(f"[{reg}]")
+        cls._bind_translate_chars = cls.bindname_escape_characters
+
     def __init__(
         self,
         dialect: Dialect,
@@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled):
 
         if not escaped_from:
 
-            if _BIND_TRANSLATE_RE.search(name):
+            if self._bind_translate_re.search(name):
                 # not quite the translate use case as we want to
                 # also get a quick boolean if we even found
                 # unusual characters in the name
-                new_name = _BIND_TRANSLATE_RE.sub(
-                    lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
+                new_name = self._bind_translate_re.sub(
+                    lambda m: self._bind_translate_chars[m.group(0)],
                     name,
                 )
                 escaped_from = name
index 39971fd76669616e970bc6340c3d7a097137495f..2907c6e0e7794cc816b715964edfb821358a971f 100644 (file)
@@ -5152,6 +5152,45 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
             render_postcompile=True,
         )
 
+    def test_bind_escape_extensibility(self):
+        """test #8994, extensibility of the bind escape character lookup.
+
+        The main test for actual known characters passing through for bound
+        params is in
+        sqlalchemy.testing.suite.test_dialect.DifficultParametersTest.
+
+        """
+        dialect = default.DefaultDialect()
+
+        class Compiler(compiler.StrSQLCompiler):
+            bindname_escape_characters = {
+                "%": "P",
+                # chars that need regex escaping
+                "(": "A",
+                ")": "Z",
+                "*": "S",
+                "+": "L",
+                # completely random "normie" character
+                "8": "E",
+                ":": "C",
+                # left bracket is not escaped, right bracket is
+                "]": "_",
+                " ": "_",
+            }
+
+        dialect.statement_compiler = Compiler
+
+        self.assert_compile(
+            select(
+                bindparam("number8ight"),
+                bindparam("plus+sign"),
+                bindparam("par(en)s and [brackets]"),
+            ),
+            "SELECT :numberEight AS anon_1, :plusLsign AS anon_2, "
+            ":parAenZs_and_[brackets_ AS anon_3",
+            dialect=dialect,
+        )
+
 
 class CompileUXTest(fixtures.TestBase):
     """tests focused on calling stmt.compile() directly, user cases"""