]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure no compiler visit method tries to access .statement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Oct 2020 14:19:29 +0000 (10:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Oct 2020 16:17:03 +0000 (12:17 -0400)
Fixed structural compiler issue where some constructs such as MySQL /
PostgreSQL "on conflict / on duplicate key" would rely upon the state of
the :class:`_sql.Compiler` object being fixed against their statement as
the top level statement, which would fail in cases where those statements
are branched from a different context, such as a DDL construct linked to a
SQL statement.

Fixes: #5656
Change-Id: I568bf40adc7edcf72ea6c7fd6eb9d07790de189e
(cherry picked from commit dfcd117afca36ea3f5f574f9bef6036dae609de8)

doc/build/changelog/unreleased_13/5656.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/assertions.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_13/5656.rst b/doc/build/changelog/unreleased_13/5656.rst
new file mode 100644 (file)
index 0000000..cdec608
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5656
+
+    Fixed structural compiler issue where some constructs such as MySQL /
+    PostgreSQL "on conflict / on duplicate key" would rely upon the state of
+    the :class:`_sql.Compiler` object being fixed against their statement as
+    the top level statement, which would fail in cases where those statements
+    are branched from a different context, such as a DDL construct linked to a
+    SQL statement.
+
index e7bf7c9b34bfffb1301be5f1a863997802c055c2..4505a83e6c0d8b863c3e5e5b181f080df6030943 100644 (file)
@@ -1383,6 +1383,8 @@ class MySQLCompiler(compiler.SQLCompiler):
         return self._render_json_extract_from_binary(binary, operator, **kw)
 
     def visit_on_duplicate_key_update(self, on_duplicate, **kw):
+        statement = self.current_executable
+
         if on_duplicate._parameter_ordering:
             parameter_ordering = [
                 elements._column_as_key(key)
@@ -1390,14 +1392,12 @@ class MySQLCompiler(compiler.SQLCompiler):
             ]
             ordered_keys = set(parameter_ordering)
             cols = [
-                self.statement.table.c[key]
+                statement.table.c[key]
                 for key in parameter_ordering
-                if key in self.statement.table.c
-            ] + [
-                c for c in self.statement.table.c if c.key not in ordered_keys
-            ]
+                if key in statement.table.c
+            ] + [c for c in statement.table.c if c.key not in ordered_keys]
         else:
-            cols = self.statement.table.c
+            cols = statement.table.c
 
         clauses = []
         # traverses through all table columns to preserve table column order
index 1295f76597b088703f67d9bb457c0e8cbcb815b6..82c05ab840f295956268d63e3539600a4fd1d1c6 100644 (file)
@@ -1951,7 +1951,7 @@ class PGCompiler(compiler.SQLCompiler):
                 "Additional column names not matching "
                 "any column keys in table '%s': %s"
                 % (
-                    self.statement.table.name,
+                    self.current_executable.table.name,
                     (", ".join("'%s'" % c for c in set_parameters)),
                 )
             )
index 3864af45b383c4c3fbd46cb7b03f20271388e521..0489881d234458b6ef8b52e81ca245a6fef189d3 100644 (file)
@@ -599,6 +599,44 @@ class SQLCompiler(Compiled):
         if self.positional and self._numeric_binds:
             self._apply_numbered_params()
 
+    @property
+    def current_executable(self):
+        """Return the current 'executable' that is being compiled.
+
+        This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
+        :class:`_sql.Update`, :class:`_sql.Delete`,
+        :class:`_sql.CompoundSelect` object that is being compiled.
+        Specifically it's assigned to the ``self.stack`` list of elements.
+
+        When a statement like the above is being compiled, it normally
+        is also assigned to the ``.statement`` attribute of the
+        :class:`_sql.Compiler` object.   However, all SQL constructs are
+        ultimately nestable, and this attribute should never be consulted
+        by a ``visit_`` method, as it is not guaranteed to be assigned
+        nor guaranteed to correspond to the current statement being compiled.
+
+        .. versionadded:: 1.3.21
+
+            For compatibility with previous versions, use the following
+            recipe::
+
+                statement = getattr(self, "current_executable", False)
+                if statement is False:
+                    statement = self.stack[-1]["selectable"]
+
+            For versions 1.4 and above, ensure only .current_executable
+            is used; the format of "self.stack" may change.
+
+
+        """
+        try:
+            return self.stack[-1]["selectable"]
+        except IndexError as ie:
+            util.raise_(
+                IndexError("Compiler does not have a stack entry"),
+                replace_context=ie,
+            )
+
     @property
     def prefetch(self):
         return list(self.insert_prefetch + self.update_prefetch)
index 2c9baf36f9c5603389daa84538ba3f4bcb150fa6..52912d8dec1ae99a5f64dabc8d707c5fa73ae0d2 100644 (file)
@@ -21,6 +21,7 @@ from .util import fail
 from .. import exc as sa_exc
 from .. import pool
 from .. import schema
+from .. import sql
 from .. import types as sqltypes
 from .. import util
 from ..engine import default
@@ -444,7 +445,60 @@ class AssertsCompiledSQL(object):
         if compile_kwargs:
             kw["compile_kwargs"] = compile_kwargs
 
-        c = clause.compile(dialect=dialect, **kw)
+        class DontAccess(object):
+            def __getattribute__(self, key):
+                raise NotImplementedError(
+                    "compiler accessed .statement; use "
+                    "compiler.current_executable"
+                )
+
+        class CheckCompilerAccess(object):
+            def __init__(self, test_statement):
+                self.test_statement = test_statement
+                self.supports_execution = getattr(
+                    test_statement, "supports_execution", False
+                )
+                if self.supports_execution:
+                    self._execution_options = test_statement._execution_options
+
+                    if isinstance(
+                        test_statement, (sql.Insert, sql.Update, sql.Delete)
+                    ):
+                        self._returning = test_statement._returning
+                    if isinstance(test_statement, (sql.Insert, sql.Update)):
+                        self.inline = test_statement.inline
+                        self._return_defaults = test_statement._return_defaults
+
+            def _default_dialect(self):
+                return self.test_statement._default_dialect()
+
+            def compile(self, dialect, **kw):
+                return self.test_statement.compile.__func__(
+                    self, dialect=dialect, **kw
+                )
+
+            def _compiler(self, dialect, **kw):
+                return self.test_statement._compiler.__func__(
+                    self, dialect, **kw
+                )
+
+            def _compiler_dispatch(self, compiler, **kwargs):
+                if hasattr(compiler, "statement"):
+                    with mock.patch.object(
+                        compiler, "statement", DontAccess()
+                    ):
+                        return self.test_statement._compiler_dispatch(
+                            compiler, **kwargs
+                        )
+                else:
+                    return self.test_statement._compiler_dispatch(
+                        compiler, **kwargs
+                    )
+
+        # no construct can assume it's the "top level" construct in all cases
+        # as anything can be nested.  ensure constructs don't assume they
+        # are the "self.statement" element
+        c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
 
         param_str = repr(getattr(c, "params", {}))
 
index 20ab18f4247e29975275fc626c726e1e09065dbc..772dff109d481e2da3b04e485100266834d42bc4 100644 (file)
@@ -74,6 +74,7 @@ from sqlalchemy.sql import compiler
 from sqlalchemy.sql import label
 from sqlalchemy.sql import table
 from sqlalchemy.sql.expression import _literal_as_text
+from sqlalchemy.sql.expression import ClauseElement
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.sql.expression import HasPrefixes
 from sqlalchemy.testing import assert_raises
@@ -83,6 +84,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import eq_ignore_whitespace
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import mock
 from sqlalchemy.util import u
 
 
@@ -146,6 +148,53 @@ keyed = Table(
 )
 
 
+class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL):
+    def test_dont_access_statement(self):
+        def visit_foobar(self, element, **kw):
+            self.statement.table
+
+        class Foobar(ClauseElement):
+            __visit_name__ = "foobar"
+
+        with mock.patch.object(
+            testing.db.dialect.statement_compiler,
+            "visit_foobar",
+            visit_foobar,
+            create=True,
+        ):
+            assert_raises_message(
+                NotImplementedError,
+                "compiler accessed .statement; use "
+                "compiler.current_executable",
+                self.assert_compile,
+                Foobar(),
+                "",
+            )
+
+    def test_no_stack(self):
+        def visit_foobar(self, element, **kw):
+            self.current_executable.table
+
+        class Foobar(ClauseElement):
+            __visit_name__ = "foobar"
+
+        with mock.patch.object(
+            testing.db.dialect.statement_compiler,
+            "visit_foobar",
+            visit_foobar,
+            create=True,
+        ):
+            compiler = testing.db.dialect.statement_compiler(
+                testing.db.dialect, None
+            )
+            assert_raises_message(
+                IndexError,
+                "Compiler does not have a stack entry",
+                compiler.process,
+                Foobar(),
+            )
+
+
 class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"