]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure all visit methods accept **kw
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Dec 2022 17:16:21 +0000 (12:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Dec 2022 18:47:43 +0000 (13:47 -0500)
Added test support to ensure that all compiler ``visit_xyz()`` methods
across all :class:`.Compiler` implementations in SQLAlchemy accept a
``**kw`` parameter, so that all compilers accept additional keyword
arguments under all circumstances.

Fixes: #8988
Change-Id: I1cefc313e4e64a10ee7dd14400137fbe02ce9523

doc/build/changelog/unreleased_20/8988.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/suite/test_dialect.py

diff --git a/doc/build/changelog/unreleased_20/8988.rst b/doc/build/changelog/unreleased_20/8988.rst
new file mode 100644 (file)
index 0000000..b5300c1
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8988
+
+    Added test support to ensure that all compiler ``visit_xyz()`` methods
+    across all :class:`.Compiler` implementations in SQLAlchemy accept a
+    ``**kw`` parameter, so that all compilers accept additional keyword
+    arguments under all circumstances.
index a0049c361e00a2d606e3728e2ad5d02298f3879f..5ecb4c512fc35ccfe837b6729011b7db2230fd18 100644 (file)
@@ -2236,12 +2236,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
         field = self.extract_map.get(extract.field, extract.field)
         return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
 
-    def visit_savepoint(self, savepoint_stmt):
+    def visit_savepoint(self, savepoint_stmt, **kw):
         return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
             savepoint_stmt
         )
 
-    def visit_rollback_to_savepoint(self, savepoint_stmt):
+    def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
         return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
             savepoint_stmt
         )
@@ -2392,7 +2392,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             for t in [from_table] + extra_froms
         )
 
-    def visit_empty_set_expr(self, type_):
+    def visit_empty_set_expr(self, type_, **kw):
         return "SELECT 1 WHERE 1!=1"
 
     def visit_is_distinct_from_binary(self, binary, operator, **kw):
@@ -2580,7 +2580,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
 
         return colspec
 
-    def visit_create_index(self, create, include_schema=False):
+    def visit_create_index(self, create, include_schema=False, **kw):
         index = create.element
         self._verify_index_table(index)
         preparer = self.preparer
@@ -2632,13 +2632,13 @@ class MSDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_drop_index(self, drop):
+    def visit_drop_index(self, drop, **kw):
         return "\nDROP INDEX %s ON %s" % (
             self._prepared_index_name(drop.element, include_schema=False),
             self.preparer.format_table(drop.element.table),
         )
 
-    def visit_primary_key_constraint(self, constraint):
+    def visit_primary_key_constraint(self, constraint, **kw):
         if len(constraint) == 0:
             return ""
         text = ""
@@ -2661,7 +2661,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_unique_constraint(self, constraint):
+    def visit_unique_constraint(self, constraint, **kw):
         if len(constraint) == 0:
             return ""
         text = ""
@@ -2684,7 +2684,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_computed_column(self, generated):
+    def visit_computed_column(self, generated, **kw):
         text = "AS (%s)" % self.sql_compiler.process(
             generated.sqltext, include_table=False, literal_binds=True
         )
@@ -2693,7 +2693,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
             text += " PERSISTED"
         return text
 
-    def visit_set_table_comment(self, create):
+    def visit_set_table_comment(self, create, **kw):
         schema = self.preparer.schema_for_object(create.element)
         schema_name = schema if schema else self.dialect.default_schema_name
         return (
@@ -2707,7 +2707,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
             )
         )
 
-    def visit_drop_table_comment(self, drop):
+    def visit_drop_table_comment(self, drop, **kw):
         schema = self.preparer.schema_for_object(drop.element)
         schema_name = schema if schema else self.dialect.default_schema_name
         return (
@@ -2718,7 +2718,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
             )
         )
 
-    def visit_set_column_comment(self, create):
+    def visit_set_column_comment(self, create, **kw):
         schema = self.preparer.schema_for_object(create.element.table)
         schema_name = schema if schema else self.dialect.default_schema_name
         return (
@@ -2735,7 +2735,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
             )
         )
 
-    def visit_drop_column_comment(self, drop):
+    def visit_drop_column_comment(self, drop, **kw):
         schema = self.preparer.schema_for_object(drop.element.table)
         schema_name = schema if schema else self.dialect.default_schema_name
         return (
index 2525c6c32e257ac382ddb92d6218226d5e3a695c..f965eac159516f0614afa8ea132558ef0be67d72 100644 (file)
@@ -1663,7 +1663,7 @@ class MySQLCompiler(compiler.SQLCompiler):
             for t in [from_table] + extra_froms
         )
 
-    def visit_empty_set_expr(self, element_types):
+    def visit_empty_set_expr(self, element_types, **kw):
         return (
             "SELECT %(outer)s FROM (SELECT %(inner)s) "
             "as _empty_set WHERE 1!=1"
@@ -1962,14 +1962,14 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_primary_key_constraint(self, constraint):
+    def visit_primary_key_constraint(self, constraint, **kw):
         text = super().visit_primary_key_constraint(constraint)
         using = constraint.dialect_options["mysql"]["using"]
         if using:
             text += " USING %s" % (self.preparer.quote(using))
         return text
 
-    def visit_drop_index(self, drop):
+    def visit_drop_index(self, drop, **kw):
         index = drop.element
         text = "\nDROP INDEX "
         if drop.if_exists:
@@ -1980,7 +1980,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             self.preparer.format_table(index.table),
         )
 
-    def visit_drop_constraint(self, drop):
+    def visit_drop_constraint(self, drop, **kw):
         constraint = drop.element
         if isinstance(constraint, sa_schema.ForeignKeyConstraint):
             qual = "FOREIGN KEY "
@@ -2014,7 +2014,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             )
         return ""
 
-    def visit_set_table_comment(self, create):
+    def visit_set_table_comment(self, create, **kw):
         return "ALTER TABLE %s COMMENT %s" % (
             self.preparer.format_table(create.element),
             self.sql_compiler.render_literal_value(
@@ -2022,12 +2022,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             ),
         )
 
-    def visit_drop_table_comment(self, create):
+    def visit_drop_table_comment(self, create, **kw):
         return "ALTER TABLE %s COMMENT ''" % (
             self.preparer.format_table(create.element)
         )
 
-    def visit_set_column_comment(self, create):
+    def visit_set_column_comment(self, create, **kw):
         return "ALTER TABLE %s CHANGE %s %s" % (
             self.preparer.format_table(create.element.table),
             self.preparer.format_column(create.element),
index dc2b011afeb95e11a4afe0f111ba05bc379a60bd..d6f65e5ed5a94571f946f17c0f6b53702d0d21ab 100644 (file)
@@ -1189,7 +1189,7 @@ class OracleCompiler(compiler.SQLCompiler):
     def limit_clause(self, select, **kw):
         return ""
 
-    def visit_empty_set_expr(self, type_):
+    def visit_empty_set_expr(self, type_, **kw):
         return "SELECT 1 FROM DUAL WHERE 1!=1"
 
     def for_update_clause(self, select, **kw):
@@ -1279,12 +1279,12 @@ class OracleDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_drop_table_comment(self, drop):
+    def visit_drop_table_comment(self, drop, **kw):
         return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table(
             drop.element
         )
 
-    def visit_create_index(self, create):
+    def visit_create_index(self, create, **kw):
         index = create.element
         self._verify_index_table(index)
         preparer = self.preparer
@@ -1336,7 +1336,7 @@ class OracleDDLCompiler(compiler.DDLCompiler):
         text = text.replace("NO ORDER", "NOORDER")
         return text
 
-    def visit_computed_column(self, generated):
+    def visit_computed_column(self, generated, **kw):
         text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
             generated.sqltext, include_table=False, literal_binds=True
         )
index 8287e828a74c5608b8da1dd587e9a172b02ff905..3fb29812b949e01023f5ada2ace49204ed1521eb 100644 (file)
@@ -1836,7 +1836,7 @@ class PGCompiler(compiler.SQLCompiler):
                 self.process(flags, **kw),
             )
 
-    def visit_empty_set_expr(self, element_types):
+    def visit_empty_set_expr(self, element_types, **kw):
         # cast the empty set to the type we are comparing against.  if
         # we are comparing against the null type, pick an arbitrary
         # datatype for the empty set
@@ -2144,7 +2144,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
         not_valid = constraint.dialect_options["postgresql"]["not_valid"]
         return " NOT VALID" if not_valid else ""
 
-    def visit_check_constraint(self, constraint):
+    def visit_check_constraint(self, constraint, **kw):
         if constraint._type_bound:
             typ = list(constraint.columns)[0].type
             if (
@@ -2162,12 +2162,12 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text += self._define_constraint_validity(constraint)
         return text
 
-    def visit_foreign_key_constraint(self, constraint):
+    def visit_foreign_key_constraint(self, constraint, **kw):
         text = super().visit_foreign_key_constraint(constraint)
         text += self._define_constraint_validity(constraint)
         return text
 
-    def visit_create_enum_type(self, create):
+    def visit_create_enum_type(self, create, **kw):
         type_ = create.element
 
         return "CREATE TYPE %s AS ENUM (%s)" % (
@@ -2178,12 +2178,12 @@ class PGDDLCompiler(compiler.DDLCompiler):
             ),
         )
 
-    def visit_drop_enum_type(self, drop):
+    def visit_drop_enum_type(self, drop, **kw):
         type_ = drop.element
 
         return "DROP TYPE %s" % (self.preparer.format_type(type_))
 
-    def visit_create_domain_type(self, create):
+    def visit_create_domain_type(self, create, **kw):
         domain: DOMAIN = create.element
 
         options = []
@@ -2211,11 +2211,11 @@ class PGDDLCompiler(compiler.DDLCompiler):
             f"{' '.join(options)}"
         )
 
-    def visit_drop_domain_type(self, drop):
+    def visit_drop_domain_type(self, drop, **kw):
         domain = drop.element
         return f"DROP DOMAIN {self.preparer.format_type(domain)}"
 
-    def visit_create_index(self, create):
+    def visit_create_index(self, create, **kw):
         preparer = self.preparer
         index = create.element
         self._verify_index_table(index)
@@ -2303,7 +2303,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_drop_index(self, drop):
+    def visit_drop_index(self, drop, **kw):
         index = drop.element
 
         text = "\nDROP INDEX "
@@ -2382,7 +2382,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         return "".join(table_opts)
 
-    def visit_computed_column(self, generated):
+    def visit_computed_column(self, generated, **kw):
         if generated.persisted is False:
             raise exc.CompileError(
                 "PostrgreSQL computed columns do not support 'virtual' "
index 5d8b3fbad8bc08bc4bd834a63abc06639a1f80fb..5a0761e5fe5d1c9c733ca1bd9d06a1baf7b1a6f4 100644 (file)
@@ -1423,12 +1423,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
             self.process(binary.right, **kw),
         )
 
-    def visit_empty_set_op_expr(self, type_, expand_op):
+    def visit_empty_set_op_expr(self, type_, expand_op, **kw):
         # slightly old SQLite versions don't seem to be able to handle
         # the empty set impl
         return self.visit_empty_set_expr(type_)
 
-    def visit_empty_set_expr(self, element_types):
+    def visit_empty_set_expr(self, element_types, **kw):
         return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
             ", ".join("1" for type_ in element_types or [INTEGER()]),
             ", ".join("1" for type_ in element_types or [INTEGER()]),
@@ -1595,7 +1595,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
         return colspec
 
-    def visit_primary_key_constraint(self, constraint):
+    def visit_primary_key_constraint(self, constraint, **kw):
         # for columns with sqlite_autoincrement=True,
         # the PRIMARY KEY constraint can only be inline
         # with the column itself.
@@ -1624,7 +1624,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_unique_constraint(self, constraint):
+    def visit_unique_constraint(self, constraint, **kw):
         text = super().visit_unique_constraint(constraint)
 
         on_conflict_clause = constraint.dialect_options["sqlite"][
@@ -1642,7 +1642,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_check_constraint(self, constraint):
+    def visit_check_constraint(self, constraint, **kw):
         text = super().visit_check_constraint(constraint)
 
         on_conflict_clause = constraint.dialect_options["sqlite"][
@@ -1654,7 +1654,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_column_check_constraint(self, constraint):
+    def visit_column_check_constraint(self, constraint, **kw):
         text = super().visit_column_check_constraint(constraint)
 
         if constraint.dialect_options["sqlite"]["on_conflict"] is not None:
@@ -1665,7 +1665,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_foreign_key_constraint(self, constraint):
+    def visit_foreign_key_constraint(self, constraint, **kw):
 
         local_table = constraint.elements[0].parent.table
         remote_table = constraint.elements[0].column.table
@@ -1681,7 +1681,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
         return preparer.format_table(table, use_schema=False)
 
     def visit_create_index(
-        self, create, include_schema=False, include_table_schema=True
+        self, create, include_schema=False, include_table_schema=True, **kw
     ):
         index = create.element
         self._verify_index_table(index)
index 66a294d1061bf98cc349a3313074e2e91a597697..43bccfd3a14cb5c1ac7bf8b8fe4d6f3018dbaa45 100644 (file)
@@ -724,7 +724,7 @@ class Compiled:
         else:
             raise exc.ObjectNotExecutableError(self.statement)
 
-    def visit_unsupported_compilation(self, element, err):
+    def visit_unsupported_compilation(self, element, err, **kw):
         raise exc.UnsupportedCompilationError(self, type(element)) from err
 
     @property
@@ -2846,7 +2846,7 @@ class SQLCompiler(Compiled):
             binary, OPERATORS[operator], **kw
         )
 
-    def visit_empty_set_op_expr(self, type_, expand_op):
+    def visit_empty_set_op_expr(self, type_, expand_op, **kw):
         if expand_op is operators.not_in_op:
             if len(type_) > 1:
                 return "(%s)) OR (1 = 1" % (
@@ -2864,7 +2864,7 @@ class SQLCompiler(Compiled):
         else:
             return self.visit_empty_set_expr(type_)
 
-    def visit_empty_set_expr(self, element_types):
+    def visit_empty_set_expr(self, element_types, **kw):
         raise NotImplementedError(
             "Dialect '%s' does not support empty set expression."
             % self.dialect.name
@@ -5624,15 +5624,15 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def visit_savepoint(self, savepoint_stmt):
+    def visit_savepoint(self, savepoint_stmt, **kw):
         return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
-    def visit_rollback_to_savepoint(self, savepoint_stmt):
+    def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
         return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
             savepoint_stmt
         )
 
-    def visit_release_savepoint(self, savepoint_stmt):
+    def visit_release_savepoint(self, savepoint_stmt, **kw):
         return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
             savepoint_stmt
         )
@@ -5720,7 +5720,7 @@ class StrSQLCompiler(SQLCompiler):
             for t in extra_froms
         )
 
-    def visit_empty_set_expr(self, type_):
+    def visit_empty_set_expr(self, type_, **kw):
         return "SELECT 1 WHERE 1!=1"
 
     def get_from_hint_text(self, table, text):
index 945edef85b1bec8682357ee65df2d77c7fd782bb..38fe8f9c46d69d54e1e0a37fcdec0a032bdafeb6 100644 (file)
@@ -1,12 +1,15 @@
 # mypy: ignore-errors
 
 
+import importlib
+
 from . import testing
 from .. import assert_raises
 from .. import config
 from .. import engines
 from .. import eq_
 from .. import fixtures
+from .. import is_not_none
 from .. import is_true
 from .. import ne_
 from .. import provide_metadata
@@ -17,12 +20,15 @@ from ..provision import set_default_schema_on_connection
 from ..schema import Column
 from ..schema import Table
 from ... import bindparam
+from ... import dialects
 from ... import event
 from ... import exc
 from ... import Integer
 from ... import literal_column
 from ... import select
 from ... import String
+from ...sql.compiler import Compiled
+from ...util import inspect_getfullargspec
 
 
 class PingTest(fixtures.TestBase):
@@ -35,6 +41,51 @@ class PingTest(fixtures.TestBase):
             )
 
 
+class ArgSignatureTest(fixtures.TestBase):
+    """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
+    ``**kw``, for #8988.
+
+    This test uses runtime code inspection.   Does not need to be a
+    ``__backend__`` test as it only needs to run once provided all target
+    dialects have been imported.
+
+    For third party dialects, the suite would be run with that third
+    party as a "--dburi", which means its compiler classes will have been
+    imported by the time this test runs.
+
+    """
+
+    def _all_subclasses():  # type: ignore  # noqa
+        for d in dialects.__all__:
+            if not d.startswith("_"):
+                importlib.import_module("sqlalchemy.dialects.%s" % d)
+
+        stack = [Compiled]
+
+        while stack:
+            cls = stack.pop(0)
+            stack.extend(cls.__subclasses__())
+            yield cls
+
+    @testing.fixture(params=list(_all_subclasses()))
+    def all_subclasses(self, request):
+        yield request.param
+
+    def test_all_visit_methods_accept_kw(self, all_subclasses):
+        cls = all_subclasses
+
+        for k in cls.__dict__:
+            if k.startswith("visit_"):
+                meth = getattr(cls, k)
+
+                insp = inspect_getfullargspec(meth)
+                is_not_none(
+                    insp.varkw,
+                    f"Compiler visit method {cls.__name__}.{k}() does "
+                    "not accommodate for **kw in its argument signature",
+                )
+
+
 class ExceptionTest(fixtures.TablesTest):
     """Test basic exception wrapping.