]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure no compiler visit method tries to access .statement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Oct 2020 14:19:29 +0000 (10:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Oct 2020 17:13:15 +0000 (13:13 -0400)
Fixed structural compiler issue where some constructs such as MySQL /
PostgreSQL "on conflict / on duplicate key" would rely upon the state of
the :class:`_sql.Compiler` object being fixed against their statement as
the top level statement, which would fail in cases where those statements
are branched from a different context, such as a DDL construct linked to a
SQL statement.

Fixes: #5656
Change-Id: I568bf40adc7edcf72ea6c7fd6eb9d07790de189e

doc/build/changelog/unreleased_13/5656.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/assertions.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_13/5656.rst b/doc/build/changelog/unreleased_13/5656.rst
new file mode 100644 (file)
index 0000000..cdec608
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5656
+
+    Fixed structural compiler issue where some constructs such as MySQL /
+    PostgreSQL "on conflict / on duplicate key" would rely upon the state of
+    the :class:`_sql.Compiler` object being fixed against their statement as
+    the top level statement, which would fail in cases where those statements
+    are branched from a different context, such as a DDL construct linked to a
+    SQL statement.
+
index 3e6c676ababffcf121ec9e7179946d462cf80f8a..77f65799c2ea47bb63cfece35d9b8088c25249c2 100644 (file)
@@ -1461,6 +1461,8 @@ class MySQLCompiler(compiler.SQLCompiler):
         return self._render_json_extract_from_binary(binary, operator, **kw)
 
     def visit_on_duplicate_key_update(self, on_duplicate, **kw):
+        statement = self.current_executable
+
         if on_duplicate._parameter_ordering:
             parameter_ordering = [
                 coercions.expect(roles.DMLColumnRole, key)
@@ -1468,14 +1470,12 @@ class MySQLCompiler(compiler.SQLCompiler):
             ]
             ordered_keys = set(parameter_ordering)
             cols = [
-                self.statement.table.c[key]
+                statement.table.c[key]
                 for key in parameter_ordering
-                if key in self.statement.table.c
-            ] + [
-                c for c in self.statement.table.c if c.key not in ordered_keys
-            ]
+                if key in statement.table.c
+            ] + [c for c in statement.table.c if c.key not in ordered_keys]
         else:
-            cols = self.statement.table.c
+            cols = statement.table.c
 
         clauses = []
         # traverses through all table columns to preserve table column order
index 44272b9d35fd7200b62d7976c7e6903e24e857cb..ea6921b2db21864551b586eb0143865d275da442 100644 (file)
@@ -2101,7 +2101,7 @@ class PGCompiler(compiler.SQLCompiler):
                 "Additional column names not matching "
                 "any column keys in table '%s': %s"
                 % (
-                    self.statement.table.name,
+                    self.current_executable.table.name,
                     (", ".join("'%s'" % c for c in set_parameters)),
                 )
             )
index 022f6611f7b165abf05db6f1ebc8aa86f4a4b2d5..4a0b2d07d9ee8c0dc801d7af3edeaf72dcc3456c 100644 (file)
@@ -2171,10 +2171,9 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
 
         # if we are against a lambda statement we might not be the
         # topmost object that received per-execute annotations
-        top_level_stmt = compiler.statement
+
         if (
-            top_level_stmt._annotations.get("synchronize_session", None)
-            == "fetch"
+            compiler._annotations.get("synchronize_session", None) == "fetch"
             and compiler.dialect.full_returning
         ):
             new_stmt = new_stmt.returning(*mapper.primary_key)
@@ -2287,8 +2286,6 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
         ext_info = statement.table._annotations["parententity"]
         self.mapper = mapper = ext_info.mapper
 
-        top_level_stmt = compiler.statement
-
         self.extra_criteria_entities = {}
 
         extra_criteria_attributes = {}
@@ -2305,7 +2302,7 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
 
         if (
             mapper
-            and top_level_stmt._annotations.get("synchronize_session", None)
+            and compiler._annotations.get("synchronize_session", None)
             == "fetch"
             and compiler.dialect.full_returning
         ):
index 8718e15ea9b64c1f7c742dcc08648df0696ccd15..10499975c7b26edd0c2fcd406a0a5840250e06eb 100644 (file)
@@ -377,12 +377,14 @@ class Compiled(object):
 
     schema_translate_map = None
 
-    execution_options = util.immutabledict()
+    execution_options = util.EMPTY_DICT
     """
     Execution options propagated from the statement.   In some cases,
     sub-elements of the statement can modify these.
     """
 
+    _annotations = util.EMPTY_DICT
+
     compile_state = None
     """Optional :class:`.CompileState` object that maintains additional
     state used by the compiler.
@@ -474,6 +476,7 @@ class Compiled(object):
         if statement is not None:
             self.statement = statement
             self.can_execute = statement.supports_execution
+            self._annotations = statement._annotations
             if self.can_execute:
                 self.execution_options = statement._execution_options
             self.string = self.process(self.statement, **compile_kwargs)
@@ -798,6 +801,44 @@ class SQLCompiler(Compiled):
         if self._render_postcompile:
             self._process_parameters_for_postcompile(_populate_self=True)
 
+    @property
+    def current_executable(self):
+        """Return the current 'executable' that is being compiled.
+
+        This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
+        :class:`_sql.Update`, :class:`_sql.Delete`,
+        :class:`_sql.CompoundSelect` object that is being compiled.
+        Specifically it's assigned to the ``self.stack`` list of elements.
+
+        When a statement like the above is being compiled, it normally
+        is also assigned to the ``.statement`` attribute of the
+        :class:`_sql.Compiler` object.   However, all SQL constructs are
+        ultimately nestable, and this attribute should never be consulted
+        by a ``visit_`` method, as it is not guaranteed to be assigned
+        nor guaranteed to correspond to the current statement being compiled.
+
+        .. versionadded:: 1.3.21
+
+            For compatibility with previous versions, use the following
+            recipe::
+
+                statement = getattr(self, "current_executable", False)
+                if statement is False:
+                    statement = self.stack[-1]["selectable"]
+
+            For versions 1.4 and above, ensure only .current_executable
+            is used; the format of "self.stack" may change.
+
+
+        """
+        try:
+            return self.stack[-1]["selectable"]
+        except IndexError as ie:
+            util.raise_(
+                IndexError("Compiler does not have a stack entry"),
+                replace_context=ie,
+            )
+
     @property
     def prefetch(self):
         return list(self.insert_prefetch + self.update_prefetch)
index af168cd85245dc7e90d5845a1b0fe033b608f9c7..17a0acf20d436b934095759a58b1934182da771b 100644 (file)
@@ -20,6 +20,7 @@ from .exclusions import db_spec
 from .util import fail
 from .. import exc as sa_exc
 from .. import schema
+from .. import sql
 from .. import types as sqltypes
 from .. import util
 from ..engine import default
@@ -441,7 +442,61 @@ class AssertsCompiledSQL(object):
         if compile_kwargs:
             kw["compile_kwargs"] = compile_kwargs
 
-        c = clause.compile(dialect=dialect, **kw)
+        class DontAccess(object):
+            def __getattribute__(self, key):
+                raise NotImplementedError(
+                    "compiler accessed .statement; use "
+                    "compiler.current_executable"
+                )
+
+        class CheckCompilerAccess(object):
+            def __init__(self, test_statement):
+                self.test_statement = test_statement
+                self._annotations = {}
+                self.supports_execution = getattr(
+                    test_statement, "supports_execution", False
+                )
+                if self.supports_execution:
+                    self._execution_options = test_statement._execution_options
+
+                    if isinstance(
+                        test_statement, (sql.Insert, sql.Update, sql.Delete)
+                    ):
+                        self._returning = test_statement._returning
+                    if isinstance(test_statement, (sql.Insert, sql.Update)):
+                        self._inline = test_statement._inline
+                        self._return_defaults = test_statement._return_defaults
+
+            def _default_dialect(self):
+                return self.test_statement._default_dialect()
+
+            def compile(self, dialect, **kw):
+                return self.test_statement.compile.__func__(
+                    self, dialect=dialect, **kw
+                )
+
+            def _compiler(self, dialect, **kw):
+                return self.test_statement._compiler.__func__(
+                    self, dialect, **kw
+                )
+
+            def _compiler_dispatch(self, compiler, **kwargs):
+                if hasattr(compiler, "statement"):
+                    with mock.patch.object(
+                        compiler, "statement", DontAccess()
+                    ):
+                        return self.test_statement._compiler_dispatch(
+                            compiler, **kwargs
+                        )
+                else:
+                    return self.test_statement._compiler_dispatch(
+                        compiler, **kwargs
+                    )
+
+        # no construct can assume it's the "top level" construct in all cases
+        # as anything can be nested.  ensure constructs don't assume they
+        # are the "self.statement" element
+        c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
 
         param_str = repr(getattr(c, "params", {}))
         if util.py3k:
index ef2f75b2d6caeb076d2c9d2b7500039abe2fbd02..c9e1d9ab419c30a9f7cce10de66f641731274868 100644 (file)
@@ -77,6 +77,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql.elements import BooleanClauseList
+from sqlalchemy.sql.expression import ClauseElement
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.sql.expression import HasPrefixes
 from sqlalchemy.testing import assert_raises
@@ -158,6 +159,53 @@ keyed = Table(
 )
 
 
+class TestCompilerFixture(fixtures.TestBase, AssertsCompiledSQL):
+    def test_dont_access_statement(self):
+        def visit_foobar(self, element, **kw):
+            self.statement.table
+
+        class Foobar(ClauseElement):
+            __visit_name__ = "foobar"
+
+        with mock.patch.object(
+            testing.db.dialect.statement_compiler,
+            "visit_foobar",
+            visit_foobar,
+            create=True,
+        ):
+            assert_raises_message(
+                NotImplementedError,
+                "compiler accessed .statement; use "
+                "compiler.current_executable",
+                self.assert_compile,
+                Foobar(),
+                "",
+            )
+
+    def test_no_stack(self):
+        def visit_foobar(self, element, **kw):
+            self.current_executable.table
+
+        class Foobar(ClauseElement):
+            __visit_name__ = "foobar"
+
+        with mock.patch.object(
+            testing.db.dialect.statement_compiler,
+            "visit_foobar",
+            visit_foobar,
+            create=True,
+        ):
+            compiler = testing.db.dialect.statement_compiler(
+                testing.db.dialect, None
+            )
+            assert_raises_message(
+                IndexError,
+                "Compiler does not have a stack entry",
+                compiler.process,
+                Foobar(),
+            )
+
+
 class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"