From d7107641c309e0b7db9b0876ac048dbb38316ba6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 16 Dec 2022 12:56:21 -0500 Subject: [PATCH] make bind escape lookup extensible To accommodate third party dialects with different character escaping needs regarding bound parameters, the system by which SQLAlchemy "escapes" (i.e., replaces with another character in its place) special characters in bound parameter names has been made extensible for third party dialects, using the :attr:`.SQLCompiler.bindname_escape_characters` dictionary which can be overridden at the class declaration level on any :class:`.SQLCompiler` subclass. As part of this change, also added the dot ``"."`` as a default "escaped" character. Fixes: #8994 Change-Id: I52fbbfa8c64497b123f57327113df3f022bd1419 --- doc/build/changelog/unreleased_20/8994.rst | 13 ++++ lib/sqlalchemy/dialects/oracle/cx_oracle.py | 35 ++++++---- lib/sqlalchemy/sql/compiler.py | 75 +++++++++++++++++++-- test/sql/test_compiler.py | 39 +++++++++++ 4 files changed, 144 insertions(+), 18 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8994.rst diff --git a/doc/build/changelog/unreleased_20/8994.rst b/doc/build/changelog/unreleased_20/8994.rst new file mode 100644 index 0000000000..cd2a056fa7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8994.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, sql + :tickets: 8994 + + To accommodate third party dialects with different character escaping + needs regarding bound parameters, the system by which SQLAlchemy "escapes" + (i.e., replaces with another character in its place) special characters in + bound parameter names has been made extensible for third party dialects, + using the :attr:`.SQLCompiler.bindname_escape_characters` dictionary which can + be overridden at the class declaration level on any :class:`.SQLCompiler` + subclass. As part of this change, also added the dot ``"."`` as a default + "escaped" character. 
+ diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 8f80aed656..c45aafae6b 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -445,15 +445,6 @@ from ...sql._typing import is_sql_compiler _CX_ORACLE_MAGIC_LOB_SIZE = 131072 -_ORACLE_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]\.\/\? ]") - -# Oracle bind names can't start with digits or underscores. -# currently we rely upon Oracle-specific quoting of bind names in most cases. -# however for expanding params, the escape chars are used. -# see #8708 -_ORACLE_BIND_TRANSLATE_CHARS = dict(zip("%():[]./? ", "PAZCCCCCCCC")) - - class _OracleInteger(sqltypes.Integer): def get_dbapi_type(self, dbapi): # see https://github.com/oracle/python-cx_Oracle/issues/ @@ -694,6 +685,26 @@ class OracleCompiler_cx_oracle(OracleCompiler): _oracle_returning = False + # Oracle bind names can't start with digits or underscores. + # currently we rely upon Oracle-specific quoting of bind names in most + # cases. however for expanding params, the escape chars are used. 
+ # see #8708 + bindname_escape_characters = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "C", + "[": "C", + "]": "C", + " ": "C", + "\\": "C", + "/": "C", + "?": "C", + } + ) + def bindparam_string(self, name, **kw): quote = getattr(name, "quote", None) if ( @@ -721,12 +732,12 @@ class OracleCompiler_cx_oracle(OracleCompiler): escaped_from = kw.get("escaped_from", None) if not escaped_from: - if _ORACLE_BIND_TRANSLATE_RE.search(name): + if self._bind_translate_re.search(name): # not quite the translate use case as we want to # also get a quick boolean if we even found # unusual characters in the name - new_name = _ORACLE_BIND_TRANSLATE_RE.sub( - lambda m: _ORACLE_BIND_TRANSLATE_CHARS[m.group(0)], + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], name, ) if new_name[0].isdigit() or new_name[0] == "_": diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 66a294d106..596ca986f0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -37,6 +37,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import ClassVar from typing import Dict from typing import FrozenSet from typing import Iterable @@ -46,6 +47,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import NoReturn from typing import Optional +from typing import Pattern from typing import Sequence from typing import Set from typing import Tuple @@ -238,9 +240,6 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) - OPERATORS = { # binary operators.and_: " AND ", @@ -714,6 +713,14 @@ class Compiled: self._gen_time = perf_counter() + def __init_subclass__(cls) -> None: + cls._init_compiler_cls() + return super().__init_subclass__() + + @classmethod + def _init_compiler_cls(cls): + pass + def _execute_on_connection( self, 
connection, distilled_params, execution_options ): @@ -866,6 +873,52 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + bindname_escape_characters: ClassVar[ + Mapping[str, str] + ] = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) + """A mapping (e.g. dict or similar) containing a lookup of + characters keyed to replacement characters which will be applied to all + 'bind names' used in SQL statements as a form of 'escaping'; the given + characters are replaced entirely with the 'replacement' character when + rendered in the SQL statement, and a similar translation is performed + on the incoming names used in parameter dictionaries passed to methods + like :meth:`_engine.Connection.execute`. + + This allows bound parameter names used in :func:`_sql.bindparam` and + other constructs to have any arbitrary characters present without any + concern for characters that aren't allowed at all on the target database. + + Third party dialects can establish their own dictionary here to replace the + default mapping, which will ensure that the particular characters in the + mapping will never appear in a bound parameter name. + + The dictionary is evaluated at **class creation time**, so cannot be + modified at runtime; it must be present on the class when the class + is first declared. + + Note that for dialects that have additional bound parameter rules such + as additional restrictions on leading characters, the + :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented. + See the cx_Oracle compiler for an example of this. + + .. 
versionadded:: 2.0.0b5 + + """ + + _bind_translate_re: ClassVar[Pattern[str]] + _bind_translate_chars: ClassVar[Mapping[str, str]] + is_sql = True compound_keywords = COMPOUND_KEYWORDS @@ -1108,6 +1161,16 @@ class SQLCompiler(Compiled): f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" ) + @classmethod + def _init_compiler_cls(cls): + cls._init_bind_translate() + + @classmethod + def _init_bind_translate(cls): + reg = re.escape("".join(cls.bindname_escape_characters)) + cls._bind_translate_re = re.compile(f"[{reg}]") + cls._bind_translate_chars = cls.bindname_escape_characters + def __init__( self, dialect: Dialect, @@ -3591,12 +3654,12 @@ class SQLCompiler(Compiled): if not escaped_from: - if _BIND_TRANSLATE_RE.search(name): + if self._bind_translate_re.search(name): # not quite the translate use case as we want to # also get a quick boolean if we even found # unusual characters in the name - new_name = _BIND_TRANSLATE_RE.sub( - lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], name, ) escaped_from = name diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 39971fd766..2907c6e0e7 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -5152,6 +5152,45 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): render_postcompile=True, ) + def test_bind_escape_extensibility(self): + """test #8994, extensibility of the bind escape character lookup. + + The main test for actual known characters passing through for bound + params is in + sqlalchemy.testing.suite.test_dialect.DifficultParametersTest. 
+ + """ + dialect = default.DefaultDialect() + + class Compiler(compiler.StrSQLCompiler): + bindname_escape_characters = { + "%": "P", + # chars that need regex escaping + "(": "A", + ")": "Z", + "*": "S", + "+": "L", + # completely random "normie" character + "8": "E", + ":": "C", + # left bracket is not escaped, right bracket is + "]": "_", + " ": "_", + } + + dialect.statement_compiler = Compiler + + self.assert_compile( + select( + bindparam("number8ight"), + bindparam("plus+sign"), + bindparam("par(en)s and [brackets]"), + ), + "SELECT :numberEight AS anon_1, :plusLsign AS anon_2, " + ":parAenZs_and_[brackets_ AS anon_3", + dialect=dialect, + ) + class CompileUXTest(fixtures.TestBase): """tests focused on calling stmt.compile() directly, user cases""" -- 2.47.2