From 21ccdd0960ab6c928db24891398362b4d1037d23 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 19 Oct 2020 10:19:29 -0400 Subject: [PATCH] Ensure no compiler visit method tries to access .statement Fixed structural compiler issue where some constructs such as MySQL / PostgreSQL "on conflict / on duplicate key" would rely upon the state of the :class:`_sql.Compiler` object being fixed against their statement as the top level statement, which would fail in cases where those statements are branched from a different context, such as a DDL construct linked to a SQL statement. Fixes: #5656 Change-Id: I568bf40adc7edcf72ea6c7fd6eb9d07790de189e (cherry picked from commit dfcd117afca36ea3f5f574f9bef6036dae609de8) --- doc/build/changelog/unreleased_13/5656.rst | 11 +++++ lib/sqlalchemy/dialects/mysql/base.py | 12 ++--- lib/sqlalchemy/dialects/postgresql/base.py | 2 +- lib/sqlalchemy/sql/compiler.py | 38 +++++++++++++++ lib/sqlalchemy/testing/assertions.py | 56 +++++++++++++++++++++- test/sql/test_compiler.py | 49 +++++++++++++++++++ 6 files changed, 160 insertions(+), 8 deletions(-) create mode 100644 doc/build/changelog/unreleased_13/5656.rst diff --git a/doc/build/changelog/unreleased_13/5656.rst b/doc/build/changelog/unreleased_13/5656.rst new file mode 100644 index 0000000000..cdec608424 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5656.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 5656 + + Fixed structural compiler issue where some constructs such as MySQL / + PostgreSQL "on conflict / on duplicate key" would rely upon the state of + the :class:`_sql.Compiler` object being fixed against their statement as + the top level statement, which would fail in cases where those statements + are branched from a different context, such as a DDL construct linked to a + SQL statement. + diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e7bf7c9b34..4505a83e6c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1383,6 +1383,8 @@ class MySQLCompiler(compiler.SQLCompiler): return self._render_json_extract_from_binary(binary, operator, **kw) def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable + if on_duplicate._parameter_ordering: parameter_ordering = [ elements._column_as_key(key) @@ -1390,14 +1392,12 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] + statement.table.c[key] for key in parameter_ordering - if key in self.statement.table.c - ] + [ - c for c in self.statement.table.c if c.key not in ordered_keys - ] + if key in statement.table.c + ] + [c for c in statement.table.c if c.key not in ordered_keys] else: - cols = self.statement.table.c + cols = statement.table.c clauses = [] # traverses through all table columns to preserve table column order diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 1295f76597..82c05ab840 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1951,7 +1951,7 @@ class PGCompiler(compiler.SQLCompiler): "Additional column names not matching " "any column keys in table '%s': %s" % ( - self.statement.table.name, + self.current_executable.table.name, (", ".join("'%s'" % c for c in set_parameters)), ) ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3864af45b3..0489881d23 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -599,6 +599,44 @@ class SQLCompiler(Compiled): if self.positional and self._numeric_binds: self._apply_numbered_params() + @property + def current_executable(self): + """Return the current 'executable' that is being compiled. + + This is currently the :class:`_sql.Select`, :class:`_sql.Insert`, + :class:`_sql.Update`, :class:`_sql.Delete`, + :class:`_sql.CompoundSelect` object that is being compiled. + Specifically it's assigned to the ``self.stack`` list of elements. + + When a statement like the above is being compiled, it normally + is also assigned to the ``.statement`` attribute of the + :class:`_sql.Compiler` object. However, all SQL constructs are + ultimately nestable, and this attribute should never be consulted + by a ``visit_`` method, as it is not guaranteed to be assigned + nor guaranteed to correspond to the current statement being compiled. + + .. versionadded:: 1.3.21 + + For compatibility with previous versions, use the following + recipe:: + + statement = getattr(self, "current_executable", False) + if statement is False: + statement = self.stack[-1]["selectable"] + + For versions 1.4 and above, ensure only .current_executable + is used; the format of "self.stack" may change. + + + """ + try: + return self.stack[-1]["selectable"] + except IndexError as ie: + util.raise_( + IndexError("Compiler does not have a stack entry"), + replace_context=ie, + ) + @property def prefetch(self): return list(self.insert_prefetch + self.update_prefetch) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 2c9baf36f9..52912d8dec 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -21,6 +21,7 @@ from .util import fail from .. import exc as sa_exc from .. import pool from .. import schema +from .. import sql from .. import types as sqltypes from .. import util from ..engine import default @@ -444,7 +445,60 @@ class AssertsCompiledSQL(object): if compile_kwargs: kw["compile_kwargs"] = compile_kwargs - c = clause.compile(dialect=dialect, **kw) + class DontAccess(object): + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess(object): + def __init__(self, test_statement): + self.test_statement = test_statement + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if isinstance( + test_statement, (sql.Insert, sql.Update, sql.Delete) + ): + self._returning = test_statement._returning + if isinstance(test_statement, (sql.Insert, sql.Update)): + self.inline = test_statement.inline + self._return_defaults = test_statement._return_defaults + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) param_str = repr(getattr(c, "params", {})) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 20ab18f424..772dff109d 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -74,6 +74,7 @@ from sqlalchemy.sql import compiler from sqlalchemy.sql import label from sqlalchemy.sql import table from sqlalchemy.sql.expression import _literal_as_text +from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseList from sqlalchemy.sql.expression import HasPrefixes from sqlalchemy.testing import assert_raises @@ -83,6 +84,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import eq_ignore_whitespace from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import mock from sqlalchemy.util import u @@ -146,6 +148,53 @@ keyed = Table( ) +class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL): + def test_dont_access_statement(self): + def visit_foobar(self, element, **kw): + self.statement.table + + class Foobar(ClauseElement): + __visit_name__ = "foobar" + + with mock.patch.object( + testing.db.dialect.statement_compiler, + "visit_foobar", + visit_foobar, + create=True, + ): + assert_raises_message( + NotImplementedError, + "compiler accessed .statement; use " + "compiler.current_executable", + self.assert_compile, + Foobar(), + "", + ) + + def test_no_stack(self): + def visit_foobar(self, element, **kw): + self.current_executable.table + + class Foobar(ClauseElement): + __visit_name__ = "foobar" + + with mock.patch.object( + testing.db.dialect.statement_compiler, + "visit_foobar", + visit_foobar, + create=True, + ): + compiler = testing.db.dialect.statement_compiler( + testing.db.dialect, None + ) + assert_raises_message( + IndexError, + "Compiler does not have a stack entry", + compiler.process, + Foobar(), + ) + + class SelectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" -- 2.47.3