]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Rewrite positional handling, test for "numeric"
authorFederico Caselli <cfederico87@gmail.com>
Fri, 2 Dec 2022 16:58:40 +0000 (11:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Dec 2022 14:59:01 +0000 (09:59 -0500)
Changed how the positional compilation is performed. It's rendered by the compiler
the same as the pyformat compilation. The string is then processed to replace
the placeholders with the correct ones, and to obtain the correct order of the
parameters.
This vastly simplifies the computation of the order of the parameters, that in
case of nested CTE is very hard to compute correctly.

Reworked how numeric paramstyle behavers:
- added support for repeated parameter, without duplicating them like in normal
positional dialects
- implement insertmany support. This requires that the dialect supports out of
order placehoders, since all parameters that are not part of the VALUES clauses
are placed at the beginning of the parameter tuple
- support for different identifiers for a numeric parameter. It's for example
possible to use postgresql style placeholder $1, $2, etc

Added two new dialect based on sqlite to test "numeric" fully using
both :1 style and $1 style. Includes a workaround for SQLite's
not-really-correct numeric implementation.

Changed parmstyle of asyncpg dialect to use numeric, rendering with its native
$ identifiers

Fixes: #8926
Fixes: #8849
Change-Id: I7c640467d49adfe6d795cc84296fc7403dcad4d6

28 files changed:
doc/build/changelog/unreleased_20/8849.rst [new file with mode: 0644]
doc/build/changelog/unreleased_20/8926.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/sqlite/provision.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/plugin/plugin_base.py
setup.cfg
test/dialect/postgresql/test_query.py
test/dialect/test_sqlite.py
test/engine/test_logging.py
test/orm/dml/test_bulk_statements.py
test/orm/test_dynamic.py
test/orm/test_merge.py
test/orm/test_unitofworkv2.py
test/requirements.py
test/sql/test_compiler.py
test/sql/test_cte.py
test/sql/test_insert.py
test/sql/test_resultset.py
test/sql/test_types.py
test/sql/test_update.py
tox.ini

diff --git a/doc/build/changelog/unreleased_20/8849.rst b/doc/build/changelog/unreleased_20/8849.rst
new file mode 100644 (file)
index 0000000..29ecf2a
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8849
+
+    Reworked how numeric paramstyle behavers, in particular, fixed insertmany
+    behaviour that prior to this was non functional; added support for repeated
+    parameter without duplicating them like in other positional dialects;
+    introduced new numeric paramstyle called ``numeric_dollar`` that can be
+    used to render statements that use the PostgreSQL placeholder style (
+    i.e. ``$1, $2, $3``).
+    This change requires that the dialect supports out of order placehoders,
+    that may be used used in the statements, in particular when using
+    insert-many values with statement that have parameters in the returning
+    clause.
diff --git a/doc/build/changelog/unreleased_20/8926.rst b/doc/build/changelog/unreleased_20/8926.rst
new file mode 100644 (file)
index 0000000..a0000fb
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: asyncpg
+    :tickets: 8926
+
+    Changed the paramstyle used by asyncpg from ``format`` to
+    ``numeric_dollar``. This has two main benefits since it does not require
+    additional processing of the statement and allows for duplicate parameters
+    to be present in the statements.
index 751dc3dcf39ac68bca7b8f962dd3b65ea299db50..b8f614eba5edfb692ad4218593c2bb88c731035d 100644 (file)
@@ -438,9 +438,6 @@ class AsyncAdapt_asyncpg_cursor:
     def _handle_exception(self, error):
         self._adapt_connection._handle_exception(error)
 
-    def _parameter_placeholders(self, params):
-        return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
-
     async def _prepare_and_execute(self, operation, parameters):
         adapt_connection = self._adapt_connection
 
@@ -449,11 +446,7 @@ class AsyncAdapt_asyncpg_cursor:
             if not adapt_connection._started:
                 await adapt_connection._start_transaction()
 
-            if parameters is not None:
-                operation = operation % self._parameter_placeholders(
-                    parameters
-                )
-            else:
+            if parameters is None:
                 parameters = ()
 
             try:
@@ -506,10 +499,6 @@ class AsyncAdapt_asyncpg_cursor:
             if not adapt_connection._started:
                 await adapt_connection._start_transaction()
 
-            operation = operation % self._parameter_placeholders(
-                seq_of_parameters[0]
-            )
-
             try:
                 return await self._connection.executemany(
                     operation, seq_of_parameters
@@ -808,7 +797,7 @@ class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
 class AsyncAdapt_asyncpg_dbapi:
     def __init__(self, asyncpg):
         self.asyncpg = asyncpg
-        self.paramstyle = "format"
+        self.paramstyle = "numeric_dollar"
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
@@ -900,7 +889,7 @@ class PGDialect_asyncpg(PGDialect):
     render_bind_cast = True
     has_terminate = True
 
-    default_paramstyle = "format"
+    default_paramstyle = "numeric_dollar"
     supports_sane_multi_rowcount = False
     execution_ctx_cls = PGExecutionContext_asyncpg
     statement_compiler = PGCompiler_asyncpg
index 05ee6c625938f1cd644a97dee8d98e83c43967e5..851f0951fc0efd9d8f53d16648f367d1e0cf4455 100644 (file)
@@ -18,7 +18,13 @@ from ...testing.provision import upsert
 
 
 # TODO: I can't get this to build dynamically with pytest-xdist procs
-_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
+_drivernames = {
+    "pysqlite",
+    "aiosqlite",
+    "pysqlcipher",
+    "pysqlite_numeric",
+    "pysqlite_dollar",
+}
 
 
 @generate_driver_url.for_db("sqlite")
index 4475ccae7a8e812c776b2db42c35d4d3269e16be..c04a3601dc36fed95804ea54baa781c861478d60 100644 (file)
@@ -637,3 +637,110 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
 
 
 dialect = SQLiteDialect_pysqlite
+
+
+class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
+    """numeric dialect for testing only
+
+    internal use only.  This dialect is **NOT** supported by SQLAlchemy
+    and may change at any time.
+
+    """
+
+    supports_statement_cache = True
+    default_paramstyle = "numeric"
+    driver = "pysqlite_numeric"
+
+    _first_bind = ":1"
+    _not_in_statement_regexp = None
+
+    def __init__(self, *arg, **kw):
+        kw.setdefault("paramstyle", "numeric")
+        super().__init__(*arg, **kw)
+
+    def create_connect_args(self, url):
+        arg, opts = super().create_connect_args(url)
+        opts["factory"] = self._fix_sqlite_issue_99953()
+        return arg, opts
+
+    def _fix_sqlite_issue_99953(self):
+        import sqlite3
+
+        first_bind = self._first_bind
+        if self._not_in_statement_regexp:
+            nis = self._not_in_statement_regexp
+
+            def _test_sql(sql):
+                m = nis.search(sql)
+                assert not m, f"Found {nis.pattern!r} in {sql!r}"
+
+        else:
+
+            def _test_sql(sql):
+                pass
+
+        def _numeric_param_as_dict(parameters):
+            if parameters:
+                assert isinstance(parameters, tuple)
+                return {
+                    str(idx): value for idx, value in enumerate(parameters, 1)
+                }
+            else:
+                return ()
+
+        class SQLiteFix99953Cursor(sqlite3.Cursor):
+            def execute(self, sql, parameters=()):
+                _test_sql(sql)
+                if first_bind in sql:
+                    parameters = _numeric_param_as_dict(parameters)
+                return super().execute(sql, parameters)
+
+            def executemany(self, sql, parameters):
+                _test_sql(sql)
+                if first_bind in sql:
+                    parameters = [
+                        _numeric_param_as_dict(p) for p in parameters
+                    ]
+                return super().executemany(sql, parameters)
+
+        class SQLiteFix99953Connection(sqlite3.Connection):
+            def cursor(self, factory=None):
+                if factory is None:
+                    factory = SQLiteFix99953Cursor
+                return super().cursor(factory=factory)
+
+            def execute(self, sql, parameters=()):
+                _test_sql(sql)
+                if first_bind in sql:
+                    parameters = _numeric_param_as_dict(parameters)
+                return super().execute(sql, parameters)
+
+            def executemany(self, sql, parameters):
+                _test_sql(sql)
+                if first_bind in sql:
+                    parameters = [
+                        _numeric_param_as_dict(p) for p in parameters
+                    ]
+                return super().executemany(sql, parameters)
+
+        return SQLiteFix99953Connection
+
+
+class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
+    """numeric dialect that uses $ for testing only
+
+    internal use only.  This dialect is **NOT** supported by SQLAlchemy
+    and may change at any time.
+
+    """
+
+    supports_statement_cache = True
+    default_paramstyle = "numeric_dollar"
+    driver = "pysqlite_dollar"
+
+    _first_bind = "$1"
+    _not_in_statement_regexp = re.compile(r"[^\d]:\d+")
+
+    def __init__(self, *arg, **kw):
+        kw.setdefault("paramstyle", "numeric_dollar")
+        super().__init__(*arg, **kw)
index 3cc9cab8b39efd8f28a74093c4bc6591fe5c981d..4647c84d108315fb16a21a56331571e9be29be71 100644 (file)
@@ -320,7 +320,12 @@ class DefaultDialect(Dialect):
             self.paramstyle = self.dbapi.paramstyle
         else:
             self.paramstyle = self.default_paramstyle
-        self.positional = self.paramstyle in ("qmark", "format", "numeric")
+        self.positional = self.paramstyle in (
+            "qmark",
+            "format",
+            "numeric",
+            "numeric_dollar",
+        )
         self.identifier_preparer = self.preparer(self)
         self._on_connect_isolation_level = isolation_level
 
index 2f5efce259389e6ef4fac6ff27b5bce36c8cb699..ddf8a53fb4ea531a2ffb9d29bd363867049ab4b7 100644 (file)
@@ -255,7 +255,9 @@ SchemaTranslateMapType = Mapping[Optional[str], Optional[str]]
 
 _ImmutableExecuteOptions = immutabledict[str, Any]
 
-_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+_ParamStyle = Literal[
+    "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"
+]
 
 _GenericSetInputSizesType = List[Tuple[str, Any, "TypeEngine[Any]"]]
 
index 7ac279ee2eb1330114e84057d5ac630a5c918f7c..d7358ad3befbb749a66bb06290f17a57c39142c0 100644 (file)
@@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
 BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
 BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
 
+_pyformat_template = "%%(%(name)s)s"
 BIND_TEMPLATES = {
-    "pyformat": "%%(%(name)s)s",
+    "pyformat": _pyformat_template,
     "qmark": "?",
     "format": "%%s",
     "numeric": ":[_POSITION]",
+    "numeric_dollar": "$[_POSITION]",
     "named": ":%(name)s",
 }
 
@@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple):
     num_positional_params_counted: int
 
 
+class CompilerState(IntEnum):
+    COMPILING = 0
+    """statement is present, compilation phase in progress"""
+
+    STRING_APPLIED = 1
+    """statement is present, string form of the statement has been applied.
+
+    Additional processors by subclasses may still be pending.
+
+    """
+
+    NO_STATEMENT = 2
+    """compiler does not have a statement to compile, is used
+    for method access"""
+
+
 class Linting(IntEnum):
     """represent preferences for the 'SQL linting' feature.
 
@@ -527,6 +545,14 @@ class Compiled:
     defaults.
     """
 
+    statement: Optional[ClauseElement] = None
+    "The statement to compile."
+    string: str = ""
+    "The string representation of the ``statement``"
+
+    state: CompilerState
+    """description of the compiler's state"""
+
     is_sql = False
     is_ddl = False
 
@@ -618,7 +644,6 @@ class Compiled:
 
 
         """
-
         self.dialect = dialect
         self.preparer = self.dialect.identifier_preparer
         if schema_translate_map:
@@ -628,6 +653,7 @@ class Compiled:
             )
 
         if statement is not None:
+            self.state = CompilerState.COMPILING
             self.statement = statement
             self.can_execute = statement.supports_execution
             self._annotations = statement._annotations
@@ -641,6 +667,11 @@ class Compiled:
                 self.string = self.preparer._render_schema_translates(
                     self.string, schema_translate_map
                 )
+
+            self.state = CompilerState.STRING_APPLIED
+        else:
+            self.state = CompilerState.NO_STATEMENT
+
         self._gen_time = perf_counter()
 
     def _execute_on_connection(
@@ -672,7 +703,10 @@ class Compiled:
     def __str__(self) -> str:
         """Return the string text of the generated SQL or DDL."""
 
-        return self.string or ""
+        if self.state is CompilerState.STRING_APPLIED:
+            return self.string
+        else:
+            return ""
 
     def construct_params(
         self,
@@ -859,6 +893,19 @@ class SQLCompiler(Compiled):
     driver/DB enforces this
     """
 
+    bindtemplate: str
+    """template to render bound parameters based on paramstyle."""
+
+    compilation_bindtemplate: str
+    """template used by compiler to render parameters before positional
+    paramstyle application"""
+
+    _numeric_binds_identifier_char: str
+    """Character that's used to as the identifier of a numerical bind param.
+    For example if this char is set to ``$``, numerical binds will be rendered
+    in the form ``$1, $2, $3``.
+    """
+
     _result_columns: List[ResultColumnsEntry]
     """relates label names in the final SQL to a tuple of local
     column/label name, ColumnElement object (if any) and
@@ -967,13 +1014,17 @@ class SQLCompiler(Compiled):
     and is combined with the :attr:`_sql.Compiled.params` dictionary to
     render parameters.
 
+    This sequence always contains the unescaped name of the parameters.
+
     .. seealso::
 
         :ref:`faq_sql_expression_string` - includes a usage example for
         debugging use cases.
 
     """
-    positiontup_level: Optional[Dict[str, int]] = None
+    _values_bindparam: Optional[List[str]] = None
+
+    _visited_bindparam: Optional[List[str]] = None
 
     inline: bool = False
 
@@ -988,9 +1039,12 @@ class SQLCompiler(Compiled):
     level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
 
     ctes_recursive: bool
-    cte_positional: Dict[CTE, List[str]]
-    cte_level: Dict[CTE, int]
-    cte_order: Dict[Optional[CTE], List[CTE]]
+
+    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
+    _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
+    _positional_pattern = re.compile(
+        f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
+    )
 
     def __init__(
         self,
@@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled):
         # true if the paramstyle is positional
         self.positional = dialect.positional
         if self.positional:
-            self.positiontup_level = {}
-            self.positiontup = []
-            self._numeric_binds = dialect.paramstyle == "numeric"
-        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
+            if nb:
+                self._numeric_binds_identifier_char = (
+                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
+                )
+
+            self.compilation_bindtemplate = _pyformat_template
+        else:
+            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
         self.ctes = None
 
@@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled):
                 ):
                     self.inline = True
 
-        if self.positional and self._numeric_binds:
-            self._apply_numbered_params()
+        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+
+        if self.state is CompilerState.STRING_APPLIED:
+            if self.positional:
+                if self._numeric_binds:
+                    self._process_numeric()
+                else:
+                    self._process_positional()
 
-        if self._render_postcompile:
-            self._process_parameters_for_postcompile(_populate_self=True)
+            if self._render_postcompile:
+                self._process_parameters_for_postcompile(_populate_self=True)
 
     @property
     def insert_single_values_expr(self) -> Optional[str]:
@@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled):
         """
         if self.implicit_returning:
             return self.implicit_returning
-        elif is_dml(self.statement):
+        elif self.statement is not None and is_dml(self.statement):
             return [
                 c
                 for c in self.statement._all_selected_columns
@@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled):
         self.level_name_by_cte = {}
 
         self.ctes_recursive = False
-        if self.positional:
-            self.cte_positional = {}
-            self.cte_level = {}
-            self.cte_order = collections.defaultdict(list)
 
         return ctes
 
@@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled):
                 ordered_columns,
             )
 
-    def _apply_numbered_params(self):
-        poscount = itertools.count(1)
+    def _process_positional(self):
+        assert not self.positiontup
+        assert self.state is CompilerState.STRING_APPLIED
+        assert not self._numeric_binds
+
+        if self.dialect.paramstyle == "format":
+            placeholder = "%s"
+        else:
+            assert self.dialect.paramstyle == "qmark"
+            placeholder = "?"
+
+        positions = []
+
+        def find_position(m: re.Match[str]) -> str:
+            normal_bind = m.group(1)
+            if normal_bind:
+                positions.append(normal_bind)
+                return placeholder
+            else:
+                # this a post-compile bind
+                positions.append(m.group(2))
+                return m.group(0)
+
         self.string = re.sub(
-            r"\[_POSITION\]", lambda m: str(next(poscount)), self.string
+            self._positional_pattern, find_position, self.string
         )
 
+        if self.escaped_bind_names:
+            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+            assert len(self.escaped_bind_names) == len(reverse_escape)
+            self.positiontup = [
+                reverse_escape.get(name, name) for name in positions
+            ]
+        else:
+            self.positiontup = positions
+
+        if self._insertmanyvalues:
+            positions = []
+            single_values_expr = re.sub(
+                self._positional_pattern,
+                find_position,
+                self._insertmanyvalues.single_values_expr,
+            )
+            insert_crud_params = [
+                (
+                    v[0],
+                    v[1],
+                    re.sub(self._positional_pattern, find_position, v[2]),
+                    v[3],
+                )
+                for v in self._insertmanyvalues.insert_crud_params
+            ]
+
+            self._insertmanyvalues = _InsertManyValues(
+                is_default_expr=self._insertmanyvalues.is_default_expr,
+                single_values_expr=single_values_expr,
+                insert_crud_params=insert_crud_params,
+                num_positional_params_counted=(
+                    self._insertmanyvalues.num_positional_params_counted
+                ),
+            )
+
+    def _process_numeric(self):
+        assert self._numeric_binds
+        assert self.state is CompilerState.STRING_APPLIED
+
+        num = 1
+        param_pos: Dict[str, str] = {}
+        order: Iterable[str]
+        if self._insertmanyvalues and self._values_bindparam is not None:
+            # bindparams that are not in values are always placed first.
+            # this avoids the need of changing them when using executemany
+            # values () ()
+            order = itertools.chain(
+                (
+                    name
+                    for name in self.bind_names.values()
+                    if name not in self._values_bindparam
+                ),
+                self.bind_names.values(),
+            )
+        else:
+            order = self.bind_names.values()
+
+        for bind_name in order:
+            if bind_name in param_pos:
+                continue
+            bind = self.binds[bind_name]
+            if (
+                bind in self.post_compile_params
+                or bind in self.literal_execute_params
+            ):
+                # set to None to just mark the in positiontup, it will not
+                # be replaced below.
+                param_pos[bind_name] = None  # type: ignore
+            else:
+                ph = f"{self._numeric_binds_identifier_char}{num}"
+                num += 1
+                param_pos[bind_name] = ph
+
+        self.next_numeric_pos = num
+
+        self.positiontup = list(param_pos)
+        if self.escaped_bind_names:
+            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+            assert len(self.escaped_bind_names) == len(reverse_escape)
+            param_pos = {
+                self.escaped_bind_names.get(name, name): pos
+                for name, pos in param_pos.items()
+            }
+
+        # Can't use format here since % chars are not escaped.
+        self.string = self._pyformat_pattern.sub(
+            lambda m: param_pos[m.group(1)], self.string
+        )
+
+        if self._insertmanyvalues:
+            single_values_expr = (
+                # format is ok here since single_values_expr includes only
+                # place-holders
+                self._insertmanyvalues.single_values_expr
+                % param_pos
+            )
+            insert_crud_params = [
+                (v[0], v[1], "%s", v[3])
+                for v in self._insertmanyvalues.insert_crud_params
+            ]
+
+            self._insertmanyvalues = _InsertManyValues(
+                is_default_expr=self._insertmanyvalues.is_default_expr,
+                # This has the numbers (:1, :2)
+                single_values_expr=single_values_expr,
+                # The single binds are instead %s so they can be formatted
+                insert_crud_params=insert_crud_params,
+                num_positional_params_counted=(
+                    self._insertmanyvalues.num_positional_params_counted
+                ),
+            )
+
     @util.memoized_property
     def _bind_processors(
         self,
@@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled):
 
         new_processors: Dict[str, _BindProcessorType[Any]] = {}
 
-        if self.positional and self._numeric_binds:
-            # I'm not familiar with any DBAPI that uses 'numeric'.
-            # strategy would likely be to make use of numbers greater than
-            # the highest number present; then for expanding parameters,
-            # append them to the end of the parameter list.   that way
-            # we avoid having to renumber all the existing parameters.
-            raise NotImplementedError(
-                "'post-compile' bind parameters are not supported with "
-                "the 'numeric' paramstyle at this time."
-            )
-
         replacement_expressions: Dict[str, Any] = {}
         to_update_sets: Dict[str, Any] = {}
 
         # notes:
         # *unescaped* parameter names in:
-        # self.bind_names, self.binds, self._bind_processors
+        # self.bind_names, self.binds, self._bind_processors, self.positiontup
         #
         # *escaped* parameter names in:
         # construct_params(), replacement_expressions
 
+        numeric_positiontup: Optional[List[str]] = None
+
         if self.positional and self.positiontup is not None:
             names: Iterable[str] = self.positiontup
+            if self._numeric_binds:
+                numeric_positiontup = []
         else:
             names = self.bind_names.values()
 
+        ebn = self.escaped_bind_names
         for name in names:
-            escaped_name = (
-                self.escaped_bind_names.get(name, name)
-                if self.escaped_bind_names
-                else name
-            )
+            escaped_name = ebn.get(name, name) if ebn else name
             parameter = self.binds[name]
+
             if parameter in self.literal_execute_params:
                 if escaped_name not in replacement_expressions:
                     value = parameters.pop(escaped_name)
@@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled):
                     # in the escaped_bind_names dictionary.
                     values = parameters.pop(name)
 
-                    leep = self._literal_execute_expanding_parameter
-                    to_update, replacement_expr = leep(
+                    leep_res = self._literal_execute_expanding_parameter(
                         escaped_name, parameter, values
                     )
+                    (to_update, replacement_expr) = leep_res
 
                     to_update_sets[escaped_name] = to_update
                     replacement_expressions[escaped_name] = replacement_expr
@@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled):
                             for key, _ in to_update
                             if name in single_processors
                         )
-                    if positiontup is not None:
+                    if numeric_positiontup is not None:
+                        numeric_positiontup.extend(
+                            name for name, _ in to_update
+                        )
+                    elif positiontup is not None:
+                        # to_update has escaped names, but that's ok since
+                        # these are new names, that aren't in the
+                        # escaped_bind_names dict.
                         positiontup.extend(name for name, _ in to_update)
                     expanded_parameters[name] = [
                         expand_key for expand_key, _ in to_update
@@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled):
             return expr
 
         statement = re.sub(
-            r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
-            process_expanding,
-            self.string,
+            self._post_compile_pattern, process_expanding, self.string
         )
 
+        if numeric_positiontup is not None:
+            assert positiontup is not None
+            param_pos = {
+                key: f"{self._numeric_binds_identifier_char}{num}"
+                for num, key in enumerate(
+                    numeric_positiontup, self.next_numeric_pos
+                )
+            }
+            # Can't use format here since % chars are not escaped.
+            statement = self._pyformat_pattern.sub(
+                lambda m: param_pos[m.group(1)], statement
+            )
+            positiontup.extend(numeric_positiontup)
+
         expanded_state = ExpandedState(
             statement,
             parameters,
@@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled):
         text = self.process(taf.element, **kw)
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = (
-                self._render_cte_clause(
-                    nesting_level=nesting_level,
-                    visiting_cte=kw.get("visiting_cte"),
-                )
-                + text
-            )
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         self.stack.pop(-1)
 
@@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled):
                 self._render_cte_clause(
                     nesting_level=nesting_level,
                     include_following_stack=True,
-                    visiting_cte=kwargs.get("visiting_cte"),
                 )
                 + text
             )
@@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled):
         dialect = self.dialect
         typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
 
+        if self._numeric_binds:
+            bind_template = self.compilation_bindtemplate
+        else:
+            bind_template = self.bindtemplate
+
         if (
             self.dialect._bind_typing_render_casts
             and typ_dialect_impl.render_bind_cast
@@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled):
                 return self.render_bind_cast(
                     parameter.type,
                     typ_dialect_impl,
-                    self.bindtemplate % {"name": name},
+                    bind_template % {"name": name},
                 )
 
         else:
 
             def _render_bindtemplate(name):
-                return self.bindtemplate % {"name": name}
+                return bind_template % {"name": name}
 
         if not values:
             to_update = []
@@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled):
     def bindparam_string(
         self,
         name: str,
-        positional_names: Optional[List[str]] = None,
         post_compile: bool = False,
         expanding: bool = False,
         escaped_from: Optional[str] = None,
@@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled):
         **kw: Any,
     ) -> str:
 
-        if self.positional:
-            if positional_names is not None:
-                positional_names.append(name)
-            else:
-                self.positiontup.append(name)  # type: ignore[union-attr]
-            self.positiontup_level[name] = len(self.stack)  # type: ignore[index] # noqa: E501
+        if self._visited_bindparam is not None:
+            self._visited_bindparam.append(name)
+
         if not escaped_from:
 
             if _BIND_TRANSLATE_RE.search(name):
@@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled):
                 if type_impl.render_literal_cast:
                     ret = self.render_bind_cast(bindparam_type, type_impl, ret)
             return ret
+        elif self.state is CompilerState.COMPILING:
+            ret = self.compilation_bindtemplate % {"name": name}
         else:
             ret = self.bindtemplate % {"name": name}
 
@@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled):
                 self.level_name_by_cte[_reference_cte] = new_level_name + (
                     cte_opts,
                 )
-                if self.positional:
-                    self.cte_level[cte] = cte_level
 
         else:
             cte_level = len(self.stack) if nesting else 1
@@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled):
             self.level_name_by_cte[_reference_cte] = cte_level_name + (
                 cte_opts,
             )
-            if self.positional:
-                self.cte_level[cte] = cte_level
 
             if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
@@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled):
                         )
                     )
 
-                if self.positional:
-                    kwargs["positional_names"] = self.cte_positional[cte] = []
-
                 assert kwargs.get("subquery", False) is False
 
                 if not self.stack:
@@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled):
         # In compound query, CTEs are shared at the compound level
         if self.ctes and (not is_embedded_select or toplevel):
             nesting_level = len(self.stack) if not toplevel else None
-            text = (
-                self._render_cte_clause(
-                    nesting_level=nesting_level,
-                    visiting_cte=kwargs.get("visiting_cte"),
-                )
-                + text
-            )
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled):
         self,
         nesting_level=None,
         include_following_stack=False,
-        visiting_cte=None,
     ):
         """
         include_following_stack
@@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled):
             return ""
         ctes_recursive = any([cte.recursive for cte in ctes])
 
-        if self.positional:
-            self.cte_order[visiting_cte].extend(ctes)
-
-            if visiting_cte is None and self.cte_order:
-                assert self.positiontup is not None
-
-                def get_nested_positional(cte):
-                    if cte in self.cte_order:
-                        children = self.cte_order.pop(cte)
-                        to_add = list(
-                            itertools.chain.from_iterable(
-                                get_nested_positional(child_cte)
-                                for child_cte in children
-                            )
-                        )
-                        if cte in self.cte_positional:
-                            return reorder_positional(
-                                self.cte_positional[cte],
-                                to_add,
-                                self.cte_level[children[0]],
-                            )
-                        else:
-                            return to_add
-                    else:
-                        return self.cte_positional.get(cte, [])
-
-                def reorder_positional(pos, to_add, level):
-                    if not level:
-                        return to_add + pos
-                    index = 0
-                    for index, name in enumerate(reversed(pos)):
-                        if self.positiontup_level[name] < level:  # type: ignore[index] # noqa: E501
-                            break
-                    return pos[:-index] + to_add + pos[-index:]
-
-                to_add = get_nested_positional(None)
-                self.positiontup = reorder_positional(
-                    self.positiontup, to_add, nesting_level
-                )
-
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
@@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled):
             keys_to_replace = set()
             base_parameters = {}
             executemany_values_w_comma = f"({imv.single_values_expr}), "
+            if self._numeric_binds:
+                escaped = re.escape(self._numeric_binds_identifier_char)
+                executemany_values_w_comma = re.sub(
+                    rf"{escaped}\d+", "%s", executemany_values_w_comma
+                )
 
         while batches:
             batch = batches[0:batch_size]
@@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled):
 
                 num_ins_params = imv.num_positional_params_counted
 
+                batch_iterator: Iterable[Tuple[Any, ...]]
                 if num_ins_params == len(batch[0]):
                     extra_params = ()
-                    batch_iterator: Iterable[Tuple[Any, ...]] = batch
-                elif self.returning_precedes_values:
+                    batch_iterator = batch
+                elif self.returning_precedes_values or self._numeric_binds:
                     extra_params = batch[0][:-num_ins_params]
                     batch_iterator = (b[-num_ins_params:] for b in batch)
                 else:
                     extra_params = batch[0][num_ins_params:]
                     batch_iterator = (b[:num_ins_params] for b in batch)
 
+                values_string = (executemany_values_w_comma * len(batch))[:-2]
+                if self._numeric_binds and num_ins_params > 0:
+                    # need to format here, since statement may contain
+                    # unescaped %, while values_string contains just (%s, %s)
+                    start = len(extra_params) + 1
+                    end = num_ins_params * len(batch) + start
+                    positions = tuple(
+                        f"{self._numeric_binds_identifier_char}{i}"
+                        for i in range(start, end)
+                    )
+                    values_string = values_string % positions
+
                 replaced_statement = statement.replace(
-                    "__EXECMANY_TOKEN__",
-                    (executemany_values_w_comma * len(batch))[:-2],
+                    "__EXECMANY_TOKEN__", values_string
                 )
 
                 replaced_parameters = tuple(
                     itertools.chain.from_iterable(batch_iterator)
                 )
-                if self.returning_precedes_values:
+                if self.returning_precedes_values or self._numeric_binds:
                     replaced_parameters = extra_params + replaced_parameters
                 else:
                     replaced_parameters = replaced_parameters + extra_params
@@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled):
             }
         )
 
-        positiontup_before = positiontup_after = 0
+        counted_bindparam = 0
 
         # for positional, insertmanyvalues needs to know how many
         # bound parameters are in the VALUES sequence; there's no simple
         # rule because default expressions etc. can have zero or more
         # params inside them.   After multiple attempts to figure this out,
-        # this very simplistic "count before, then count after" works and is
+        # this very simplistic "count after" works and is
         # likely the least amount of callcounts, though looks clumsy
-        if self.positiontup:
-            positiontup_before = len(self.positiontup)
+        if self.positional:
+            self._visited_bindparam = []
 
         crud_params_struct = crud._get_crud_params(
             self, insert_stmt, compile_state, toplevel, **kw
         )
 
-        if self.positiontup:
-            positiontup_after = len(self.positiontup)
+        if self.positional:
+            assert self._visited_bindparam is not None
+            counted_bindparam = len(self._visited_bindparam)
+            if self._numeric_binds:
+                if self._values_bindparam is not None:
+                    self._values_bindparam += self._visited_bindparam
+                else:
+                    self._values_bindparam = self._visited_bindparam
+            self._visited_bindparam = None
 
         crud_params_single = crud_params_struct.single_params
 
@@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled):
 
         if self.implicit_returning or insert_stmt._returning:
 
-            # if returning clause is rendered first, capture bound parameters
-            # while visiting and place them prior to the VALUES oriented
-            # bound parameters, when using positional parameter scheme
-            rpv = self.returning_precedes_values
-            flip_pt = rpv and self.positional
-            if flip_pt:
-                pt: Optional[List[str]] = self.positiontup
-                temp_pt: Optional[List[str]]
-                self.positiontup = temp_pt = []
-            else:
-                temp_pt = pt = None
-
             returning_clause = self.returning_clause(
                 insert_stmt,
                 self.implicit_returning or insert_stmt._returning,
                 populate_result_map=toplevel,
             )
 
-            if flip_pt:
-                if TYPE_CHECKING:
-                    assert temp_pt is not None
-                    assert pt is not None
-                self.positiontup = temp_pt + pt
-
-            if rpv:
+            if self.returning_precedes_values:
                 text += " " + returning_clause
 
         else:
@@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled):
                     self._render_cte_clause(
                         nesting_level=nesting_level,
                         include_following_stack=True,
-                        visiting_cte=kw.get("visiting_cte"),
                     ),
                     select_text,
                 )
@@ -4999,7 +5150,7 @@ class SQLCompiler(Compiled):
                     cast(
                         "List[crud._CrudParamElementStr]", crud_params_single
                     ),
-                    (positiontup_after - positiontup_before),
+                    counted_bindparam,
                 )
         elif compile_state._has_multi_parameters:
             text += " VALUES %s" % (
@@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled):
                         "List[crud._CrudParamElementStr]",
                         crud_params_single,
                     ),
-                    positiontup_after - positiontup_before,
+                    counted_bindparam,
                 )
 
         if insert_stmt._post_values_clause is not None:
@@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled):
                 self._render_cte_clause(
                     nesting_level=nesting_level,
                     include_following_stack=True,
-                    visiting_cte=kw.get("visiting_cte"),
                 )
                 + text
             )
@@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled):
 
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = (
-                self._render_cte_clause(
-                    nesting_level=nesting_level,
-                    visiting_cte=kw.get("visiting_cte"),
-                )
-                + text
-            )
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         self.stack.pop(-1)
 
@@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled):
 
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = (
-                self._render_cte_clause(
-                    nesting_level=nesting_level,
-                    visiting_cte=kw.get("visiting_cte"),
-                )
-                + text
-            )
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         self.stack.pop(-1)
 
index 017ff7baa0b09c52c66e20dbed3462a888fbe1d3..ae1b032ae16079e3b0d63699c3aab12deca66720 100644 (file)
@@ -85,8 +85,8 @@ _CrudParamElement = Tuple[
 ]
 _CrudParamElementStr = Tuple[
     "KeyedColumnElement[Any]",
-    str,
-    str,
+    str,  # column name
+    str,  # placeholder
     Iterable[str],
 ]
 _CrudParamElementSQLExpr = Tuple[
index d183372c34f74ae89ecb11e404089d5c9e44f98b..45a2496dd7583e5ae3310de1695f0da2b9926a0f 100644 (file)
@@ -11,6 +11,7 @@ from __future__ import annotations
 
 import collections
 import contextlib
+import itertools
 import re
 
 from .. import event
@@ -285,7 +286,8 @@ class DialectSQL(CompiledSQL):
 
         return received_stmt, execute_observed.context.compiled_parameters
 
-    def _dialect_adjusted_statement(self, paramstyle):
+    def _dialect_adjusted_statement(self, dialect):
+        paramstyle = dialect.paramstyle
         stmt = re.sub(r"[\n\t]", "", self.statement)
 
         # temporarily escape out PG double colons
@@ -300,8 +302,14 @@ class DialectSQL(CompiledSQL):
                 repl = "?"
             elif paramstyle == "format":
                 repl = r"%s"
-            elif paramstyle == "numeric":
-                repl = None
+            elif paramstyle.startswith("numeric"):
+                counter = itertools.count(1)
+
+                num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
+
+                def repl(m):
+                    return f"{num_identifier}{next(counter)}"
+
             stmt = re.sub(r":([\w_]+)", repl, stmt)
 
         # put them back
@@ -310,20 +318,20 @@ class DialectSQL(CompiledSQL):
         return stmt
 
     def _compare_sql(self, execute_observed, received_statement):
-        paramstyle = execute_observed.context.dialect.paramstyle
-        stmt = self._dialect_adjusted_statement(paramstyle)
+        stmt = self._dialect_adjusted_statement(
+            execute_observed.context.dialect
+        )
         return received_statement == stmt
 
     def _failure_message(self, execute_observed, expected_params):
-        paramstyle = execute_observed.context.dialect.paramstyle
         return (
             "Testing for compiled statement\n%r partial params %s, "
             "received\n%%(received_statement)r with params "
             "%%(received_parameters)r"
             % (
-                self._dialect_adjusted_statement(paramstyle).replace(
-                    "%", "%%"
-                ),
+                self._dialect_adjusted_statement(
+                    execute_observed.context.dialect
+                ).replace("%", "%%"),
                 repr(expected_params).replace("%", "%%"),
             )
         )
index 9578765798402a7b414b8ddba64d0b21b5486a37..6adcf5b640afdc4130c5b8f60db26a25eb09dcff 100644 (file)
@@ -189,7 +189,7 @@ def variation(argname, cases):
             elif querytyp.legacy_query:
                 stmt = Session.query(Thing)
             else:
-                assert False
+                querytyp.fail()
 
 
     The variable provided is a slots object of boolean variables, as well
index 656a4e98a33b2d9e64e756a3ac50babfc86fbf98..ffe0f453af1f3d98ab91232a60c8b8dbc992a423 100644 (file)
@@ -371,6 +371,22 @@ def _setup_options(opt, file_config):
     options = opt
 
 
+@pre
+def _register_sqlite_numeric_dialect(opt, file_config):
+    from sqlalchemy.dialects import registry
+
+    registry.register(
+        "sqlite.pysqlite_numeric",
+        "sqlalchemy.dialects.sqlite.pysqlite",
+        "_SQLiteDialect_pysqlite_numeric",
+    )
+    registry.register(
+        "sqlite.pysqlite_dollar",
+        "sqlalchemy.dialects.sqlite.pysqlite",
+        "_SQLiteDialect_pysqlite_dollar",
+    )
+
+
 @post
 def __ensure_cext(opt, file_config):
     if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
index b02ad2682d9f3f6861b6b4cc09a19ba7b2b958bf..485f1d6820980e252aa1cff8c0f7ba382256f396 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -144,6 +144,8 @@ oracle_db_link2 = test_link2
 [db]
 default = sqlite:///:memory:
 sqlite = sqlite:///:memory:
+sqlite_numeric = sqlite+pysqlite_numeric:///:memory:
+sqlite_dollar = sqlite+pysqlite_dollar:///:memory:
 aiosqlite = sqlite+aiosqlite:///:memory:
 sqlite_file = sqlite:///querytest.db
 aiosqlite_file = sqlite+aiosqlite:///async_querytest.db
index 42ec20743d029747abe1a75c70341eed5cfb2ca5..2b32d6db7fbcdcbe9d51356bb8d052def2196963 100644 (file)
@@ -990,19 +990,19 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
                 "matchtable.title @@ plainto_tsquery(%(title_1)s)",
             )
 
-    @testing.requires.format_paramstyle
+    @testing.only_if("+asyncpg")
     def test_expression_positional(self, connection):
         matchtable = self.tables.matchtable
 
         if self._strs_render_bind_casts(connection):
             self.assert_compile(
                 matchtable.c.title.match("somstr"),
-                "matchtable.title @@ plainto_tsquery(%s::VARCHAR(200))",
+                "matchtable.title @@ plainto_tsquery($1::VARCHAR(200))",
             )
         else:
             self.assert_compile(
                 matchtable.c.title.match("somstr"),
-                "matchtable.title @@ plainto_tsquery(%s)",
+                "matchtable.title @@ plainto_tsquery($1)",
             )
 
     def test_simple_match(self, connection):
index c5147e37fb7b31f4e6aae1227bfbdc14b362c592..07117b862ef301f8efa30fd4e67dfbf788325d28 100644 (file)
@@ -2916,6 +2916,8 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
         )
 
     @testing.combinations("control", "excluded", "dict")
+    @testing.skip_if("+pysqlite_numeric")
+    @testing.skip_if("+pysqlite_dollar")
     def test_set_excluded(self, scenario):
         """test #8014, sending all of .excluded to set"""
 
index 277248617b446747117fa7fc1a7f756990a8ca5c..19c26f43c74cec51f79634a5bc5dd1a456b6a788 100644 (file)
@@ -28,7 +28,7 @@ def exec_sql(engine, sql, *args, **kwargs):
 
 
 class LogParamsTest(fixtures.TestBase):
-    __only_on__ = "sqlite"
+    __only_on__ = "sqlite+pysqlite"
     __requires__ = ("ad_hoc_engines",)
 
     def setup_test(self):
@@ -704,7 +704,7 @@ class LoggingNameTest(fixtures.TestBase):
 
 
 class TransactionContextLoggingTest(fixtures.TestBase):
-    __only_on__ = "sqlite"
+    __only_on__ = "sqlite+pysqlite"
 
     @testing.fixture()
     def plain_assert_buf(self, plain_logging_engine):
index 557b5e9da4a417daefcb60c23b02f5ae8f105fa5..78607e03d80354630534ef742ff0af297d6e0a32 100644 (file)
@@ -958,7 +958,7 @@ class BulkDMLReturningJoinedInhTest(
     BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
 ):
 
-    __requires__ = ("insert_returning",)
+    __requires__ = ("insert_returning", "insert_executemany_returning")
     __backend__ = True
 
     @classmethod
@@ -1044,7 +1044,7 @@ class BulkDMLReturningJoinedInhTest(
 class BulkDMLReturningSingleInhTest(
     BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
 ):
-    __requires__ = ("insert_returning",)
+    __requires__ = ("insert_returning", "insert_executemany_returning")
     __backend__ = True
 
     @classmethod
@@ -1075,7 +1075,7 @@ class BulkDMLReturningSingleInhTest(
 class BulkDMLReturningConcreteInhTest(
     BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
 ):
-    __requires__ = ("insert_returning",)
+    __requires__ = ("insert_returning", "insert_executemany_returning")
     __backend__ = True
 
     @classmethod
index df335f0f6726849f42c509cc36ec67f8c4bcbb01..714878f4ea144d14b02fd026c54e53298fe07cb7 100644 (file)
@@ -1791,14 +1791,33 @@ class WriteOnlyBulkTest(
                         "INSERT INTO users (name) VALUES (:name)",
                         [{"name": "x"}],
                     ),
-                    CompiledSQL(
-                        "INSERT INTO addresses (user_id, email_address) "
-                        "VALUES (:user_id, :email_address) "
-                        "RETURNING addresses.id",
+                    Conditional(
+                        testing.requires.insert_executemany_returning.enabled,
+                        [
+                            CompiledSQL(
+                                "INSERT INTO addresses "
+                                "(user_id, email_address) "
+                                "VALUES (:user_id, :email_address) "
+                                "RETURNING addresses.id",
+                                [
+                                    {"user_id": uid, "email_address": "e1"},
+                                    {"user_id": uid, "email_address": "e2"},
+                                    {"user_id": uid, "email_address": "e3"},
+                                ],
+                            )
+                        ],
                         [
-                            {"user_id": uid, "email_address": "e1"},
-                            {"user_id": uid, "email_address": "e2"},
-                            {"user_id": uid, "email_address": "e3"},
+                            CompiledSQL(
+                                "INSERT INTO addresses "
+                                "(user_id, email_address) "
+                                "VALUES (:user_id, :email_address)",
+                                param,
+                            )
+                            for param in [
+                                {"user_id": uid, "email_address": "e1"},
+                                {"user_id": uid, "email_address": "e2"},
+                                {"user_id": uid, "email_address": "e3"},
+                            ]
                         ],
                     ),
                 ],
@@ -1863,14 +1882,33 @@ class WriteOnlyBulkTest(
                         "INSERT INTO users (name) VALUES (:name)",
                         [{"name": "x"}],
                     ),
-                    CompiledSQL(
-                        "INSERT INTO addresses (user_id, email_address) "
-                        "VALUES (:user_id, :email_address) "
-                        "RETURNING addresses.id",
+                    Conditional(
+                        testing.requires.insert_executemany_returning.enabled,
+                        [
+                            CompiledSQL(
+                                "INSERT INTO addresses "
+                                "(user_id, email_address) "
+                                "VALUES (:user_id, :email_address) "
+                                "RETURNING addresses.id",
+                                [
+                                    {"user_id": uid, "email_address": "e1"},
+                                    {"user_id": uid, "email_address": "e2"},
+                                    {"user_id": uid, "email_address": "e3"},
+                                ],
+                            )
+                        ],
                         [
-                            {"user_id": uid, "email_address": "e1"},
-                            {"user_id": uid, "email_address": "e2"},
-                            {"user_id": uid, "email_address": "e3"},
+                            CompiledSQL(
+                                "INSERT INTO addresses "
+                                "(user_id, email_address) "
+                                "VALUES (:user_id, :email_address)",
+                                param,
+                            )
+                            for param in [
+                                {"user_id": uid, "email_address": "e1"},
+                                {"user_id": uid, "email_address": "e2"},
+                                {"user_id": uid, "email_address": "e3"},
+                            ]
                         ],
                     ),
                 ],
index 36c47e27be63508cf9d02c3540aaa84e766712ec..eb5a795e22fc03a1c753a809e3de28ac3e464ecc 100644 (file)
@@ -1458,7 +1458,7 @@ class MergeTest(_fixtures.FixtureTest):
             )
             attrname = "user"
         else:
-            assert False
+            direction.fail()
 
         assert attrname in obj_to_merge.__dict__
 
index f204e954cabec6b69647611d3880e7d5a647d6c1..468d43063db418a2cc1c32a8ace5fdb9f2e8f954 100644 (file)
@@ -3077,7 +3077,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
 
         asserter.assert_(
             Conditional(
-                testing.db.dialect.insert_executemany_returning,
+                testing.db.dialect.insert_returning,
                 [
                     CompiledSQL(
                         "INSERT INTO test (id) VALUES (:id) "
index 5276593c93b712d0bb191edc686762c0dab60acf..83cd65cd890aa56d3e1ea3878450f4687348e74c 100644 (file)
@@ -232,7 +232,6 @@ class DefaultRequirements(SuiteRequirements):
                 "mariadb+pymysql",
                 "mariadb+cymysql",
                 "mariadb+mysqlconnector",
-                "postgresql+asyncpg",
                 "postgresql+pg8000",
             ]
         )
@@ -387,6 +386,14 @@ class DefaultRequirements(SuiteRequirements):
             ]
         )
 
+    @property
+    def predictable_gc(self):
+        """target platform must remove all cycles unconditionally when
+        gc.collect() is called, as well as clean out unreferenced subclasses.
+
+        """
+        return self.cpython + skip_if("+aiosqlite")
+
     @property
     def memory_process_intensive(self):
         """Driver is able to handle the memory tests which run in a subprocess
@@ -969,6 +976,8 @@ class DefaultRequirements(SuiteRequirements):
             "mariadb",
             "sqlite+aiosqlite",
             "sqlite+pysqlite",
+            "sqlite+pysqlite_numeric",
+            "sqlite+pysqlite_dollar",
             "sqlite+pysqlcipher",
             "mssql",
         )
index 205ce515762d8d27b31ea495e6ae5dee5f66caf4..d342b924838561931ebf2fff9a308c38c9a60407 100644 (file)
@@ -79,6 +79,7 @@ from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.elements import CompilerColumnElement
+from sqlalchemy.sql.elements import Grouping
 from sqlalchemy.sql.expression import ClauseElement
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
@@ -4915,88 +4916,259 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
                 dialect="default",
             )
 
-    @standalone_escape
-    @testing.variation("use_assert_compile", [True, False])
     @testing.variation("use_positional", [True, False])
-    def test_standalone_bindparam_escape_expanding(
-        self, paramname, expected, use_assert_compile, use_positional
+    def test_standalone_bindparam_escape_collision(self, use_positional):
+        """this case is currently not supported
+
+        it's kinda bad since positional takes the unescaped param
+        while non positional takes the escaped one.
+        """
+        stmt = select(table1.c.myid).where(
+            table1.c.name == bindparam("[brackets]", value="x"),
+            table1.c.description == bindparam("_brackets_", value="y"),
+        )
+
+        if use_positional:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable WHERE mytable.name = ? "
+                "AND mytable.description = ?",
+                params={"[brackets]": "a", "_brackets_": "b"},
+                checkpositional=("a", "a"),
+                dialect="sqlite",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable WHERE mytable.name = "
+                ":_brackets_ AND mytable.description = :_brackets_",
+                params={"[brackets]": "a", "_brackets_": "b"},
+                checkparams={"_brackets_": "b"},
+                dialect="default",
+            )
+
+    paramstyle = testing.variation("paramstyle", ["named", "qmark", "numeric"])
+
+    @standalone_escape
+    @paramstyle
+    def test_standalone_bindparam_escape_expanding_compile(
+        self, paramname, expected, paramstyle
     ):
         stmt = select(table1.c.myid).where(
             table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
         )
 
-        if use_assert_compile:
-            if use_positional:
-                self.assert_compile(
-                    stmt,
-                    "SELECT mytable.myid FROM mytable "
-                    "WHERE mytable.name IN (?, ?)",
-                    params={paramname: ["y", "z"]},
-                    # NOTE: this is what render_postcompile will do right now
-                    # if you run construct_params().  render_postcompile mode
-                    # is not actually used by the execution internals, it's for
-                    # user-facing compilation code.  So this is likely a
-                    # current limitation of construct_params() which is not
-                    # doing the full blown postcompile; just assert that's
-                    # what it does for now.  it likely should be corrected
-                    # to make more sense.
-                    checkpositional=(["y", "z"], ["y", "z"]),
-                    dialect="sqlite",
-                    render_postcompile=True,
-                )
-            else:
-                self.assert_compile(
-                    stmt,
-                    "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
-                    "(:%s_1, :%s_2)" % (expected, expected),
-                    params={paramname: ["y", "z"]},
-                    # NOTE: this is what render_postcompile will do right now
-                    # if you run construct_params().  render_postcompile mode
-                    # is not actually used by the execution internals, it's for
-                    # user-facing compilation code.  So this is likely a
-                    # current limitation of construct_params() which is not
-                    # doing the full blown postcompile; just assert that's
-                    # what it does for now.  it likely should be corrected
-                    # to make more sense.
-                    checkparams={
-                        "%s_1" % expected: ["y", "z"],
-                        "%s_2" % expected: ["y", "z"],
-                    },
-                    dialect="default",
-                    render_postcompile=True,
-                )
+        # NOTE: below the rendered params are just what
+        # render_postcompile will do right now
+        # if you run construct_params().  render_postcompile mode
+        # is not actually used by the execution internals, it's for
+        # user-facing compilation code.  So this is likely a
+        # current limitation of construct_params() which is not
+        # doing the full blown postcompile; just assert that's
+        # what it does for now.  it likely should be corrected
+        # to make more sense.
+        if paramstyle.qmark:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable "
+                "WHERE mytable.name IN (?, ?)",
+                params={paramname: ["y", "z"]},
+                checkpositional=(["y", "z"], ["y", "z"]),
+                dialect="sqlite",
+                render_postcompile=True,
+            )
+        elif paramstyle.numeric:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable "
+                "WHERE mytable.name IN (:1, :2)",
+                params={paramname: ["y", "z"]},
+                checkpositional=(["y", "z"], ["y", "z"]),
+                dialect=sqlite.dialect(paramstyle="numeric"),
+                render_postcompile=True,
+            )
+        elif paramstyle.named:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
+                "(:%s_1, :%s_2)" % (expected, expected),
+                params={paramname: ["y", "z"]},
+                checkparams={
+                    "%s_1" % expected: ["y", "z"],
+                    "%s_2" % expected: ["y", "z"],
+                },
+                dialect="default",
+                render_postcompile=True,
+            )
         else:
-            # this is what DefaultDialect actually does.
-            # this should be matched to DefaultDialect._init_compiled()
-            if use_positional:
-                compiled = stmt.compile(
-                    dialect=default.DefaultDialect(paramstyle="qmark")
-                )
-            else:
-                compiled = stmt.compile(dialect=default.DefaultDialect())
+            paramstyle.fail()
 
-            checkparams = compiled.construct_params(
-                {paramname: ["y", "z"]}, escape_names=False
-            )
+    @standalone_escape
+    @paramstyle
+    def test_standalone_bindparam_escape_expanding(
+        self, paramname, expected, paramstyle
+    ):
+        stmt = select(table1.c.myid).where(
+            table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
+        )
+        # this is what DefaultDialect actually does.
+        # this should be matched to DefaultDialect._init_compiled()
+        if paramstyle.qmark:
+            dialect = default.DefaultDialect(paramstyle="qmark")
+        elif paramstyle.numeric:
+            dialect = default.DefaultDialect(paramstyle="numeric")
+        else:
+            dialect = default.DefaultDialect()
 
-            # nothing actually happened.  if the compiler had
-            # render_postcompile set, the
-            # above weird param thing happens
-            eq_(checkparams, {paramname: ["y", "z"]})
+        compiled = stmt.compile(dialect=dialect)
+        checkparams = compiled.construct_params(
+            {paramname: ["y", "z"]}, escape_names=False
+        )
 
-            expanded_state = compiled._process_parameters_for_postcompile(
-                checkparams
-            )
+        # nothing actually happened.  if the compiler had
+        # render_postcompile set, the
+        # above weird param thing happens
+        eq_(checkparams, {paramname: ["y", "z"]})
+
+        expanded_state = compiled._process_parameters_for_postcompile(
+            checkparams
+        )
+        eq_(
+            expanded_state.additional_parameters,
+            {f"{expected}_1": "y", f"{expected}_2": "z"},
+        )
+
+        if paramstyle.qmark or paramstyle.numeric:
             eq_(
-                expanded_state.additional_parameters,
-                {f"{expected}_1": "y", f"{expected}_2": "z"},
+                expanded_state.positiontup,
+                [f"{expected}_1", f"{expected}_2"],
             )
 
-            if use_positional:
-                eq_(
-                    expanded_state.positiontup,
-                    [f"{expected}_1", f"{expected}_2"],
+    @paramstyle
+    def test_expanding_in_repeated(self, paramstyle):
+        stmt = (
+            select(table1)
+            .where(
+                table1.c.name.in_(
+                    bindparam("uname", value=["h", "e"], expanding=True)
+                )
+                | table1.c.name.in_(
+                    bindparam("uname2", value=["y"], expanding=True)
+                )
+            )
+            .where(table1.c.myid == 8)
+        )
+        stmt = stmt.union(
+            select(table1)
+            .where(
+                table1.c.name.in_(
+                    bindparam("uname", value=["h", "e"], expanding=True)
+                )
+                | table1.c.name.in_(
+                    bindparam("uname2", value=["y"], expanding=True)
                 )
+            )
+            .where(table1.c.myid == 9)
+        ).order_by("myid")
+
+        # NOTE: below the rendered params are just what
+        # render_postcompile will do right now
+        # if you run construct_params().  render_postcompile mode
+        # is not actually used by the execution internals, it's for
+        # user-facing compilation code.  So this is likely a
+        # current limitation of construct_params() which is not
+        # doing the full blown postcompile; just assert that's
+        # what it does for now.  it likely should be corrected
+        # to make more sense.
+
+        if paramstyle.qmark:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (?, ?) OR "
+                "mytable.name IN (?)) "
+                "AND mytable.myid = ? "
+                "UNION SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (?, ?) OR "
+                "mytable.name IN (?)) "
+                "AND mytable.myid = ? ORDER BY myid",
+                params={"uname": ["y", "z"], "uname2": ["a"]},
+                checkpositional=(
+                    ["y", "z"],
+                    ["y", "z"],
+                    ["a"],
+                    8,
+                    ["y", "z"],
+                    ["y", "z"],
+                    ["a"],
+                    9,
+                ),
+                dialect="sqlite",
+                render_postcompile=True,
+            )
+        elif paramstyle.numeric:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (:3, :4) OR "
+                "mytable.name IN (:5)) "
+                "AND mytable.myid = :1 "
+                "UNION SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (:3, :4) OR "
+                "mytable.name IN (:5)) "
+                "AND mytable.myid = :2 ORDER BY myid",
+                params={"uname": ["y", "z"], "uname2": ["a"]},
+                checkpositional=(8, 9, ["y", "z"], ["y", "z"], ["a"]),
+                dialect=sqlite.dialect(paramstyle="numeric"),
+                render_postcompile=True,
+            )
+        elif paramstyle.named:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (:uname_1, :uname_2) OR "
+                "mytable.name IN (:uname2_1)) "
+                "AND mytable.myid = :myid_1 "
+                "UNION SELECT mytable.myid, mytable.name, mytable.description "
+                "FROM mytable "
+                "WHERE (mytable.name IN (:uname_1, :uname_2) OR "
+                "mytable.name IN (:uname2_1)) "
+                "AND mytable.myid = :myid_2 ORDER BY myid",
+                params={"uname": ["y", "z"], "uname2": ["a"]},
+                checkparams={
+                    "uname": ["y", "z"],
+                    "uname2": ["a"],
+                    "uname_1": ["y", "z"],
+                    "uname_2": ["y", "z"],
+                    "uname2_1": ["a"],
+                    "myid_1": 8,
+                    "myid_2": 9,
+                },
+                dialect="default",
+                render_postcompile=True,
+            )
+        else:
+            paramstyle.fail()
+
+    def test_numeric_dollar_bindparam(self):
+        stmt = table1.select().where(
+            table1.c.name == "a", table1.c.myid.in_([1, 2])
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid, mytable.name, mytable.description "
+            "FROM mytable "
+            "WHERE mytable.name = $1 "
+            "AND mytable.myid IN ($2, $3)",
+            checkpositional=("a", 1, 2),
+            dialect=default.DefaultDialect(paramstyle="numeric_dollar"),
+            render_postcompile=True,
+        )
 
 
 class UnsupportedTest(fixtures.TestBase):
@@ -5096,6 +5268,28 @@ class StringifySpecialTest(fixtures.TestBase):
             "INSERT INTO mytable (myid) VALUES (:myid_m0), (:myid_m1)",
         )
 
+    def test_multirow_insert_positional(self):
+        stmt = table1.insert().values([{"myid": 1}, {"myid": 2}])
+        eq_ignore_whitespace(
+            stmt.compile(dialect=sqlite.dialect()).string,
+            "INSERT INTO mytable (myid) VALUES (?), (?)",
+        )
+
+    def test_multirow_insert_numeric(self):
+        stmt = table1.insert().values([{"myid": 1}, {"myid": 2}])
+        eq_ignore_whitespace(
+            stmt.compile(dialect=sqlite.dialect(paramstyle="numeric")).string,
+            "INSERT INTO mytable (myid) VALUES (:1), (:2)",
+        )
+
+    def test_insert_noparams_numeric(self):
+        ii = table1.insert().returning(table1.c.myid)
+        eq_ignore_whitespace(
+            ii.compile(dialect=sqlite.dialect(paramstyle="numeric")).string,
+            "INSERT INTO mytable (myid, name, description) VALUES "
+            "(:1, :2, :3) RETURNING myid",
+        )
+
     def test_cte(self):
         # stringify of these was supported anyway by defaultdialect.
         stmt = select(table1.c.myid).cte()
@@ -5153,6 +5347,42 @@ class StringifySpecialTest(fixtures.TestBase):
             "SELECT CAST(mytable.myid AS MyType()) AS myid FROM mytable",
         )
 
+    def test_dialect_sub_compile(self):
+        class Widget(ClauseElement):
+            __visit_name__ = "widget"
+            stringify_dialect = "sqlite"
+
+        def visit_widget(self, element, **kw):
+            return "widget"
+
+        with mock.patch(
+            "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+            visit_widget,
+            create=True,
+        ):
+            eq_(str(Grouping(Widget())), "(widget)")
+
+    def test_dialect_sub_compile_w_binds(self):
+        """test sub-compile into a new compiler where
+        state != CompilerState.COMPILING, but we have to render a bindparam
+        string.  has to render the correct template.
+
+        """
+
+        class Widget(ClauseElement):
+            __visit_name__ = "widget"
+            stringify_dialect = "sqlite"
+
+        def visit_widget(self, element, **kw):
+            return f"widget {self.process(bindparam('q'), **kw)}"
+
+        with mock.patch(
+            "sqlalchemy.dialects.sqlite.base.SQLiteCompiler.visit_widget",
+            visit_widget,
+            create=True,
+        ):
+            eq_(str(Grouping(Widget())), "(widget ?)")
+
     def test_within_group(self):
         # stringify of these was supported anyway by defaultdialect.
         from sqlalchemy import within_group
index b89d18de62c51272961e9e686c11e11000d17554..502104daeadc6c2a6a6f3e9053889e28fb5b6f5a 100644 (file)
@@ -993,20 +993,20 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             s,
             'WITH regional_sales AS (SELECT orders."order" '
-            'AS "order", :1 AS anon_2 FROM orders) SELECT '
-            'regional_sales."order", :2 AS anon_1 FROM regional_sales',
-            checkpositional=("x", "y"),
+            'AS "order", :2 AS anon_2 FROM orders) SELECT '
+            'regional_sales."order", :1 AS anon_1 FROM regional_sales',
+            checkpositional=("y", "x"),
             dialect=dialect,
         )
 
         self.assert_compile(
             s.union(s),
             'WITH regional_sales AS (SELECT orders."order" '
-            'AS "order", :1 AS anon_2 FROM orders) SELECT '
-            'regional_sales."order", :2 AS anon_1 FROM regional_sales '
-            'UNION SELECT regional_sales."order", :3 AS anon_1 '
+            'AS "order", :2 AS anon_2 FROM orders) SELECT '
+            'regional_sales."order", :1 AS anon_1 FROM regional_sales '
+            'UNION SELECT regional_sales."order", :1 AS anon_1 '
             "FROM regional_sales",
-            checkpositional=("x", "y", "y"),
+            checkpositional=("y", "x"),
             dialect=dialect,
         )
 
@@ -1057,8 +1057,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             s3,
             'WITH regional_sales_1 AS (SELECT orders."order" AS "order" '
-            'FROM orders WHERE orders."order" = :1), regional_sales_2 AS '
-            '(SELECT orders."order" = :2 AS anon_1, '
+            'FROM orders WHERE orders."order" = :2), regional_sales_2 AS '
+            '(SELECT orders."order" = :1 AS anon_1, '
             'anon_2."order" AS "order", '
             'orders."order" AS order_1, '
             'regional_sales_1."order" AS order_2 FROM orders, '
@@ -1067,7 +1067,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, '
             'regional_sales_2."order", regional_sales_2.order_1, '
             "regional_sales_2.order_2 FROM regional_sales_2",
-            checkpositional=("x", "y", "z"),
+            checkpositional=("y", "x", "z"),
             dialect=dialect,
         )
 
index ac9ac4022bceb67bf2644dc7f2f3201a0aadd35e..1c24d4c79325ce3fb2f91d07fae89254a29da51d 100644 (file)
@@ -488,7 +488,8 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
-    def test_heterogeneous_multi_values(self):
+    @testing.variation("paramstyle", ["pg", "qmark", "numeric", "dollar"])
+    def test_heterogeneous_multi_values(self, paramstyle):
         """for #6047, originally I thought we'd take any insert().values()
         and be able to convert it to a "many" style execution that we can
         cache.
@@ -519,33 +520,81 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             ]
         )
 
+        pos_par = (
+            1,
+            1,
+            2,
+            2,
+            1,
+            2,
+            2,
+            3,
+            1,
+            2,
+            2,
+            10,
+        )
+
         # SQL expressions in the params at arbitrary locations means
         # we have to scan them at compile time, and the shape of the bound
         # parameters is not predictable.   so for #6047 where I originally
         # thought all of values() could be rewritten, this makes it not
         # really worth it.
-        self.assert_compile(
-            stmt,
-            "INSERT INTO t (x, y, z) VALUES "
-            "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), "
-            "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), "
-            "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))",
-            checkparams={
-                "x_m0": 1,
-                "sum_1": 1,
-                "sum_2": 2,
-                "z_m0": 2,
-                "sum_3": 1,
-                "sum_4": 2,
-                "y_m1": 2,
-                "z_m1": 3,
-                "sum_5": 1,
-                "sum_6": 2,
-                "y_m2": 2,
-                "foo_1": 10,
-            },
-            dialect=postgresql.dialect(),
-        )
+        if paramstyle.pg:
+            self.assert_compile(
+                stmt,
+                "INSERT INTO t (x, y, z) VALUES "
+                "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), "
+                "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), "
+                "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))",
+                checkparams={
+                    "x_m0": 1,
+                    "sum_1": 1,
+                    "sum_2": 2,
+                    "z_m0": 2,
+                    "sum_3": 1,
+                    "sum_4": 2,
+                    "y_m1": 2,
+                    "z_m1": 3,
+                    "sum_5": 1,
+                    "sum_6": 2,
+                    "y_m2": 2,
+                    "foo_1": 10,
+                },
+                dialect=postgresql.dialect(),
+            )
+        elif paramstyle.qmark:
+            self.assert_compile(
+                stmt,
+                "INSERT INTO t (x, y, z) VALUES "
+                "(?, sum(?, ?), ?), "
+                "(sum(?, ?), ?, ?), "
+                "(sum(?, ?), ?, foo(?))",
+                checkpositional=pos_par,
+                dialect=sqlite.dialect(),
+            )
+        elif paramstyle.numeric:
+            self.assert_compile(
+                stmt,
+                "INSERT INTO t (x, y, z) VALUES "
+                "(:1, sum(:2, :3), :4), "
+                "(sum(:5, :6), :7, :8), "
+                "(sum(:9, :10), :11, foo(:12))",
+                checkpositional=pos_par,
+                dialect=sqlite.dialect(paramstyle="numeric"),
+            )
+        elif paramstyle.dollar:
+            self.assert_compile(
+                stmt,
+                "INSERT INTO t (x, y, z) VALUES "
+                "($1, sum($2, $3), $4), "
+                "(sum($5, $6), $7, $8), "
+                "(sum($9, $10), $11, foo($12))",
+                checkpositional=pos_par,
+                dialect=sqlite.dialect(paramstyle="numeric_dollar"),
+            )
+        else:
+            paramstyle.fail()
 
     def test_insert_seq_pk_multi_values_seq_not_supported(self):
         m = MetaData()
index b856acfd32d34f792423449e04a27c3b68c2e9c6..7f1124c842a19b29e78c226d15c979a924458f52 100644 (file)
@@ -107,7 +107,7 @@ class CursorResultTest(fixtures.TablesTest):
             Column("y", String(50)),
         )
 
-    @testing.requires.insert_returning
+    @testing.requires.insert_executemany_returning
     def test_splice_horizontally(self, connection):
         users = self.tables.users
         addresses = self.tables.addresses
index 91413ff3597efdcc84c46b93a1854219f50be2ab..59519a5eccc048c6ab99597268253cd9aefd2d3e 100644 (file)
@@ -3322,7 +3322,7 @@ class ExpressionTest(
         elif expression_type.right_side:
             expr = (column("x", Integer) == Widget(52)).right
         else:
-            assert False
+            expression_type.fail()
 
         if secondary_adapt:
             is_(expr.type._type_affinity, String)
index cd7f992e22c8268888a83706ebcae3d6698da623..66971f64eb679f0ad4dbeba1fc2539232f3eae0a 100644 (file)
@@ -907,7 +907,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=dialect,
         )
 
-    def test_update_bound_ordering(self):
+    @testing.variation("paramstyle", ["qmark", "format", "numeric"])
+    def test_update_bound_ordering(self, paramstyle):
         """test that bound parameters between the UPDATE and FROM clauses
         order correctly in different SQL compilation scenarios.
 
@@ -921,30 +922,47 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             .values(name="foo")
         )
 
-        dialect = default.StrCompileDialect()
-        dialect.positional = True
-        self.assert_compile(
-            upd,
-            "UPDATE mytable SET name=:name FROM (SELECT "
-            "myothertable.otherid AS otherid, "
-            "myothertable.othername AS othername "
-            "FROM myothertable "
-            "WHERE myothertable.otherid = :otherid_1) AS anon_1 "
-            "WHERE mytable.name = anon_1.othername",
-            checkpositional=("foo", 5),
-            dialect=dialect,
-        )
+        if paramstyle.qmark:
 
-        self.assert_compile(
-            upd,
-            "UPDATE mytable, (SELECT myothertable.otherid AS otherid, "
-            "myothertable.othername AS othername "
-            "FROM myothertable "
-            "WHERE myothertable.otherid = %s) AS anon_1 SET mytable.name=%s "
-            "WHERE mytable.name = anon_1.othername",
-            checkpositional=(5, "foo"),
-            dialect=mysql.dialect(),
-        )
+            dialect = default.StrCompileDialect(paramstyle="qmark")
+            self.assert_compile(
+                upd,
+                "UPDATE mytable SET name=? FROM (SELECT "
+                "myothertable.otherid AS otherid, "
+                "myothertable.othername AS othername "
+                "FROM myothertable "
+                "WHERE myothertable.otherid = ?) AS anon_1 "
+                "WHERE mytable.name = anon_1.othername",
+                checkpositional=("foo", 5),
+                dialect=dialect,
+            )
+        elif paramstyle.format:
+            self.assert_compile(
+                upd,
+                "UPDATE mytable, (SELECT myothertable.otherid AS otherid, "
+                "myothertable.othername AS othername "
+                "FROM myothertable "
+                "WHERE myothertable.otherid = %s) AS anon_1 "
+                "SET mytable.name=%s "
+                "WHERE mytable.name = anon_1.othername",
+                checkpositional=(5, "foo"),
+                dialect=mysql.dialect(),
+            )
+        elif paramstyle.numeric:
+            dialect = default.StrCompileDialect(paramstyle="numeric")
+            self.assert_compile(
+                upd,
+                "UPDATE mytable SET name=:1 FROM (SELECT "
+                "myothertable.otherid AS otherid, "
+                "myothertable.othername AS othername "
+                "FROM myothertable "
+                "WHERE myothertable.otherid = :2) AS anon_1 "
+                "WHERE mytable.name = anon_1.othername",
+                checkpositional=("foo", 5),
+                dialect=dialect,
+            )
+        else:
+            paramstyle.fail()
 
 
 class UpdateFromCompileTest(
diff --git a/tox.ini b/tox.ini
index 50c39f610d5d569164cad4ce705505a046d55b0a..b748050525063717a88bd9e55fa0be101ded34b9 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -89,7 +89,7 @@ setenv=
     sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
 
-    py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
+    py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver pysqlite_numeric --dbdriver aiosqlite}
 
     py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}