support bind expressions w/ expanding IN; apply to psycopg2
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Wed, 13 Oct 2021 19:52:12 +0000 (15:52 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Fri, 15 Oct 2021 13:28:49 +0000 (09:28 -0400)
Fixed issue where "expanding IN" would fail to function correctly with
datatypes that use the :meth:`_types.TypeEngine.bind_expression` method,
where the method would need to be applied to each element of the
IN expression rather than the overall IN expression itself.

Fixed issue where IN expressions against a series of array elements, as can
be done with PostgreSQL, would fail to function correctly due to multiple
issues within the "expanding IN" feature of SQLAlchemy Core that was
standardized in version 1.4.  The psycopg2 dialect now makes use of the
:meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY`
to portably apply the correct casts to elements.  The asyncpg dialect was
not affected by this issue as it applies bind-level casts at the driver
level rather than at the compiler level.

As part of this commit, the "bind translate" feature has been
simplified; it now also applies to the names inside the POSTCOMPILE
tag, in order to accommodate brackets.
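
For illustration, a minimal sketch of the kind of Core expression this
repairs on psycopg2 (the "arrtable" table here is hypothetical,
mirroring the fixture added to the tests below)::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()
    arrtable = Table(
        "arrtable",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("intarr", postgresql.ARRAY(Integer)),
    )

    # "expanding IN" against a series of array values; each element of
    # the IN list is rendered as its own bound parameter
    stmt = select(arrtable.c.intarr).where(
        arrtable.c.intarr.in_([[1, 5], [4, 5, 6], [9, 10]])
    )
    # executable against a psycopg2 connection: connection.execute(stmt)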

Fixes: #7177
Change-Id: I08c703adb0a9bd6f5aeee5de3ff6f03cccdccdc5

doc/build/changelog/unreleased_14/7177.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/dialect/postgresql/test_types.py
test/sql/test_external_traversal.py
test/sql/test_type_expressions.py

diff --git a/doc/build/changelog/unreleased_14/7177.rst b/doc/build/changelog/unreleased_14/7177.rst
new file mode 100644 (file)
index 0000000..7766c83
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/7177.rst
@@ -0,0 +1,22 @@
+.. change::
+    :tags: sql, bug, regression
+    :tickets: 7177
+
+    Fixed issue where "expanding IN" would fail to function correctly with
+    datatypes that use the :meth:`_types.TypeEngine.bind_expression` method,
+    where the method would need to be applied to each element of the
+    IN expression rather than the overall IN expression itself.
+
+.. change::
+    :tags: postgresql, bug, regression
+    :tickets: 7177
+
+    Fixed issue where IN expressions against a series of array elements, as can
+    be done with PostgreSQL, would fail to function correctly due to multiple
+    issues within the "expanding IN" feature of SQLAlchemy Core that was
+    standardized in version 1.4.  The psycopg2 dialect now makes use of the
+    :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY`
+    to portably apply the correct casts to elements.  The asyncpg dialect was
+    not affected by this issue as it applies bind-level casts at the driver
+    level rather than at the compiler level.
+
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index dc3da224ca25ab4d46699d100cfa596ee51eac0c..3d195e691ae9b1aaf688ba4843d63b816c98c9a8 100644 (file)
@@ -362,7 +362,6 @@ class AsyncAdapt_asyncpg_cursor:
         if not self._inputsizes:
             return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
         else:
-
             return tuple(
                 "$%d::%s" % (idx, typ) if typ else "$%d" % idx
                 for idx, typ in enumerate(
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 2e28b45ca900a230d8db930d9e13908ec2c4cd24..c1a2cf81dcfda4cfe1fd2e01a5729b391f81f6b9 100644 (file)
@@ -2047,6 +2047,15 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             self.drop(bind=bind, checkfirst=checkfirst)
 
 
+class _ColonCast(elements.Cast):
+    __visit_name__ = "colon_cast"
+
+    def __init__(self, expression, type_):
+        self.type = type_
+        self.clause = expression
+        self.typeclause = elements.TypeClause(type_)
+
+
 colspecs = {
     sqltypes.ARRAY: _array.ARRAY,
     sqltypes.Interval: INTERVAL,
@@ -2102,6 +2111,12 @@ ischema_names = {
 
 
 class PGCompiler(compiler.SQLCompiler):
+    def visit_colon_cast(self, element, **kw):
+        return "%s::%s" % (
+            element.clause._compiler_dispatch(self, **kw),
+            element.typeclause._compiler_dispatch(self, **kw),
+        )
+
     def visit_array(self, element, **kw):
         return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
 
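For reference, a rough sketch of what the new construct renders;
_ColonCast and visit_colon_cast are private to the PostgreSQL dialect
and are used here only for illustration::

    from sqlalchemy import Integer, literal
    from sqlalchemy.dialects.postgresql import psycopg2
    from sqlalchemy.dialects.postgresql.base import _ColonCast

    # wraps a bound value in a PostgreSQL "::" cast; expected to render
    # approximately as "%(param_1)s::INTEGER" on the psycopg2 dialect
    expr = _ColonCast(literal(5, Integer()), Integer())
    print(expr.compile(dialect=psycopg2.dialect()))
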
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a71bdf7606593919b29a9a8d350b9a1f2ad93f86..4143dd041d65cc88f1bf14812ad900f27303d147 100644 (file)
@@ -473,6 +473,8 @@ import logging
 import re
 from uuid import UUID as _python_UUID
 
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -490,7 +492,6 @@ from ... import processors
 from ... import types as sqltypes
 from ... import util
 from ...engine import cursor as _cursor
-from ...sql import elements
 from ...util import collections_abc
 
 
@@ -556,6 +557,11 @@ class _PGHStore(HSTORE):
             return super(_PGHStore, self).result_processor(dialect, coltype)
 
 
+class _PGARRAY(PGARRAY):
+    def bind_expression(self, bindvalue):
+        return _ColonCast(bindvalue, self)
+
+
 class _PGJSON(JSON):
     def result_processor(self, dialect, coltype):
         return None
@@ -638,25 +644,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext):
 
 
 class PGCompiler_psycopg2(PGCompiler):
-    def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
-
-        text = super(PGCompiler_psycopg2, self).visit_bindparam(
-            bindparam, skip_bind_expression=skip_bind_expression, **kw
-        )
-        # note that if the type has a bind_expression(), we will get a
-        # double compile here
-        if not skip_bind_expression and (
-            bindparam.type._is_array or bindparam.type._is_type_decorator
-        ):
-            typ = bindparam.type._unwrapped_dialect_impl(self.dialect)
-
-            if typ._is_array:
-                text += "::%s" % (
-                    elements.TypeClause(typ)._compiler_dispatch(
-                        self, skip_bind_expression=skip_bind_expression, **kw
-                    ),
-                )
-        return text
+    pass
 
 
 class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
@@ -713,6 +701,7 @@ class PGDialect_psycopg2(PGDialect):
             sqltypes.JSON: _PGJSON,
             JSONB: _PGJSONB,
             UUID: _PGUUID,
+            sqltypes.ARRAY: _PGARRAY,
         },
     )
 
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index eff28e34008a4e89e58f3d894d324f4ef3385115..75bca190502ba3053b71046cd8149e87a2990531 100644 (file)
@@ -1584,7 +1584,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         from the bind parameter's ``TypeEngine`` objects.
 
         This method only called by those dialects which require it,
-        currently cx_oracle.
+        currently cx_oracle, asyncpg and pg8000.
 
         """
         if self.isddl or self.is_text:
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index efcfe0e51c0791ad79aed260ceb1870ada813a08..0cd568fcc643bdb226ea6efaa98c844d442a35be 100644 (file)
@@ -165,11 +165,8 @@ BIND_TEMPLATES = {
     "named": ":%(name)s",
 }
 
-BIND_TRANSLATE = {
-    "pyformat": re.compile(r"[%\(\)]"),
-    "named": re.compile(r"[\:]"),
-}
-_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"}
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
 
 OPERATORS = {
     # binary
@@ -746,7 +743,6 @@ class SQLCompiler(Compiled):
             self.positiontup = []
             self._numeric_binds = dialect.paramstyle == "numeric"
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
-        self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None)
 
         self.ctes = None
 
@@ -1113,7 +1109,6 @@ class SQLCompiler(Compiled):
           N as a bound parameter.
 
         """
-
         if parameters is None:
             parameters = self.construct_params()
 
@@ -1141,22 +1136,36 @@ class SQLCompiler(Compiled):
         replacement_expressions = {}
         to_update_sets = {}
 
+        # notes:
+        # *unescaped* parameter names in:
+        # self.bind_names, self.binds, self._bind_processors
+        #
+        # *escaped* parameter names in:
+        # construct_params(), replacement_expressions
+
         for name in (
             self.positiontup if self.positional else self.bind_names.values()
         ):
+            escaped_name = (
+                self.escaped_bind_names.get(name, name)
+                if self.escaped_bind_names
+                else name
+            )
             parameter = self.binds[name]
             if parameter in self.literal_execute_params:
-                if name not in replacement_expressions:
-                    value = parameters.pop(name)
+                if escaped_name not in replacement_expressions:
+                    value = parameters.pop(escaped_name)
 
-                replacement_expressions[name] = self.render_literal_bindparam(
+                replacement_expressions[
+                    escaped_name
+                ] = self.render_literal_bindparam(
                     parameter, render_literal_value=value
                 )
                 continue
 
             if parameter in self.post_compile_params:
-                if name in replacement_expressions:
-                    to_update = to_update_sets[name]
+                if escaped_name in replacement_expressions:
+                    to_update = to_update_sets[escaped_name]
                 else:
                     # we are removing the parameter from parameters
                     # because it is a list value, which is not expected by
@@ -1164,13 +1173,15 @@ class SQLCompiler(Compiled):
                     # process it. the single name is being replaced with
                     # individual numbered parameters for each value in the
                     # param.
-                    values = parameters.pop(name)
+                    values = parameters.pop(escaped_name)
 
                     leep = self._literal_execute_expanding_parameter
-                    to_update, replacement_expr = leep(name, parameter, values)
+                    to_update, replacement_expr = leep(
+                        escaped_name, parameter, values
+                    )
 
-                    to_update_sets[name] = to_update
-                    replacement_expressions[name] = replacement_expr
+                    to_update_sets[escaped_name] = to_update
+                    replacement_expressions[escaped_name] = replacement_expr
 
                 if not parameter.literal_execute:
                     parameters.update(to_update)
@@ -1200,10 +1211,24 @@ class SQLCompiler(Compiled):
                 positiontup.append(name)
 
         def process_expanding(m):
-            return replacement_expressions[m.group(1)]
+            key = m.group(1)
+            expr = replacement_expressions[key]
+
+            # if POSTCOMPILE included a bind_expression, render that
+            # around each element
+            if m.group(2):
+                tok = m.group(2).split("~~")
+                be_left, be_right = tok[1], tok[3]
+                expr = ", ".join(
+                    "%s%s%s" % (be_left, exp, be_right)
+                    for exp in expr.split(", ")
+                )
+            return expr
 
         statement = re.sub(
-            r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string
+            r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+            process_expanding,
+            self.string,
         )
 
         expanded_state = ExpandedState(
@@ -1963,8 +1988,10 @@ class SQLCompiler(Compiled):
         self, parameter, values
     ):
 
+        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
         if not values:
-            if parameter.type._is_tuple_type:
+            if typ_dialect_impl._is_tuple_type:
                 replacement_expression = (
                     "VALUES " if self.dialect.tuple_in_values else ""
                 ) + self.visit_empty_set_op_expr(
@@ -1977,7 +2004,7 @@ class SQLCompiler(Compiled):
                 )
 
         elif isinstance(values[0], (tuple, list)):
-            assert parameter.type._is_tuple_type
+            assert typ_dialect_impl._is_tuple_type
             replacement_expression = (
                 "VALUES " if self.dialect.tuple_in_values else ""
             ) + ", ".join(
@@ -1993,7 +2020,7 @@ class SQLCompiler(Compiled):
                 for i, tuple_element in enumerate(values)
             )
         else:
-            assert not parameter.type._is_tuple_type
+            assert not typ_dialect_impl._is_tuple_type
             replacement_expression = ", ".join(
                 self.render_literal_value(value, parameter.type)
                 for value in values
@@ -2008,9 +2035,11 @@ class SQLCompiler(Compiled):
                 parameter, values
             )
 
+        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
         if not values:
             to_update = []
-            if parameter.type._is_tuple_type:
+            if typ_dialect_impl._is_tuple_type:
 
                 replacement_expression = self.visit_empty_set_op_expr(
                     parameter.type.types, parameter.expand_op
@@ -2020,7 +2049,10 @@ class SQLCompiler(Compiled):
                     [parameter.type], parameter.expand_op
                 )
 
-        elif isinstance(values[0], (tuple, list)):
+        elif (
+            isinstance(values[0], (tuple, list))
+            and not typ_dialect_impl._is_array
+        ):
             to_update = [
                 ("%s_%s_%s" % (name, i, j), value)
                 for i, tuple_element in enumerate(values, 1)
@@ -2299,14 +2331,27 @@ class SQLCompiler(Compiled):
             impl = bindparam.type.dialect_impl(self.dialect)
             if impl._has_bind_expression:
                 bind_expression = impl.bind_expression(bindparam)
-                return self.process(
+                wrapped = self.process(
                     bind_expression,
                     skip_bind_expression=True,
                     within_columns_clause=within_columns_clause,
                     literal_binds=literal_binds,
                     literal_execute=literal_execute,
+                    render_postcompile=render_postcompile,
                     **kwargs
                 )
+                if bindparam.expanding:
+                    # for postcompile w/ expanding, move the "wrapped" part
+                    # of this into the inside
+                    m = re.match(
+                        r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+                    )
+                    wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+                        m.group(2),
+                        m.group(1),
+                        m.group(3),
+                    )
+                return wrapped
 
         if not literal_binds:
             literal_execute = (
@@ -2489,12 +2534,13 @@ class SQLCompiler(Compiled):
                 positional_names.append(name)
             else:
                 self.positiontup.append(name)
-        elif not post_compile and not escaped_from:
-            tr_reg = self._bind_translate
-            if tr_reg.search(name):
-                # i'd rather use translate() here but I can't get it to work
-                # in all cases under Python 2, not worth it right now
-                new_name = tr_reg.sub(
+        elif not escaped_from:
+
+            if _BIND_TRANSLATE_RE.search(name):
+                # not quite the translate use case as we want to
+                # also get a quick boolean if we even found
+                # unusual characters in the name
+                new_name = _BIND_TRANSLATE_RE.sub(
                     lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
                     name,
                 )
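
The "~~"-delimited tokens added to the POSTCOMPILE tag are an internal
rendering detail; a standalone sketch of the substitution performed by
process_expanding() above, using the same regular expression and token
layout (the sample statement and replacement values are made up)::

    import re

    # intermediate compiler output for an expanding parameter whose
    # type has a bind_expression(), e.g. lower()
    statement = "test_table.y IN ([POSTCOMPILE_y_1~~lower(~~REPL~~)~~])"

    # per-element replacement already computed for the parameter
    replacement_expressions = {"y_1": ":y_1_1, :y_1_2"}

    def process_expanding(m):
        expr = replacement_expressions[m.group(1)]
        if m.group(2):
            # tokens are: "", be_left, "REPL", be_right, ""
            tok = m.group(2).split("~~")
            be_left, be_right = tok[1], tok[3]
            expr = ", ".join(
                "%s%s%s" % (be_left, exp, be_right)
                for exp in expr.split(", ")
            )
        return expr

    print(
        re.sub(
            r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", process_expanding, statement
        )
    )
    # test_table.y IN (lower(:y_1_1), lower(:y_1_2))
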
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 92641fcc601b0ae0a40b64308f3321587029aa26..dd0a1be0f306e3963888a4f0a905a31a207bb23f 100644 (file)
@@ -1198,6 +1198,45 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
             postgresql.ARRAY(Unicode(30), dimensions=3), "VARCHAR(30)[][][]"
         )
 
+    def test_array_in_enum_psycopg2_cast(self):
+        expr = column(
+            "x",
+            postgresql.ARRAY(
+                postgresql.ENUM("one", "two", "three", name="myenum")
+            ),
+        ).in_([["one", "two"], ["three", "four"]])
+
+        self.assert_compile(
+            expr,
+            "x IN ([POSTCOMPILE_x_1~~~~REPL~~::myenum[]~~])",
+            dialect=postgresql.psycopg2.dialect(),
+        )
+
+        self.assert_compile(
+            expr,
+            "x IN (%(x_1_1)s::myenum[], %(x_1_2)s::myenum[])",
+            dialect=postgresql.psycopg2.dialect(),
+            render_postcompile=True,
+        )
+
+    def test_array_in_str_psycopg2_cast(self):
+        expr = column("x", postgresql.ARRAY(String(15))).in_(
+            [["one", "two"], ["three", "four"]]
+        )
+
+        self.assert_compile(
+            expr,
+            "x IN ([POSTCOMPILE_x_1~~~~REPL~~::VARCHAR(15)[]~~])",
+            dialect=postgresql.psycopg2.dialect(),
+        )
+
+        self.assert_compile(
+            expr,
+            "x IN (%(x_1_1)s::VARCHAR(15)[], %(x_1_2)s::VARCHAR(15)[])",
+            dialect=postgresql.psycopg2.dialect(),
+            render_postcompile=True,
+        )
+
     def test_array_type_render_str_collate_multidim(self):
         self.assert_compile(
             postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2),
@@ -1457,11 +1496,79 @@ class ArrayRoundTripTest(object):
         t = Table(
             "t",
             metadata,
-            Column("data", sqltypes.ARRAY(String(50, collation="en_US"))),
+            Column("data", self.ARRAY(String(50, collation="en_US"))),
         )
 
         t.create(connection)
 
+    @testing.fixture
+    def array_in_fixture(self, connection):
+        arrtable = self.tables.arrtable
+
+        connection.execute(
+            arrtable.insert(),
+            [
+                {
+                    "id": 1,
+                    "intarr": [1, 2, 3],
+                    "strarr": [u"one", u"two", u"three"],
+                },
+                {
+                    "id": 2,
+                    "intarr": [4, 5, 6],
+                    "strarr": [u"four", u"five", u"six"],
+                },
+                {"id": 3, "intarr": [1, 5], "strarr": [u"one", u"five"]},
+                {"id": 4, "intarr": [], "strarr": []},
+            ],
+        )
+
+    def test_array_in_int(self, array_in_fixture, connection):
+        """test #7177"""
+
+        arrtable = self.tables.arrtable
+
+        stmt = (
+            select(arrtable.c.intarr)
+            .where(arrtable.c.intarr.in_([[1, 5], [4, 5, 6], [9, 10]]))
+            .order_by(arrtable.c.id)
+        )
+
+        eq_(
+            connection.execute(stmt).all(),
+            [
+                ([4, 5, 6],),
+                ([1, 5],),
+            ],
+        )
+
+    def test_array_in_str(self, array_in_fixture, connection):
+        """test #7177"""
+
+        arrtable = self.tables.arrtable
+
+        stmt = (
+            select(arrtable.c.strarr)
+            .where(
+                arrtable.c.strarr.in_(
+                    [
+                        [u"one", u"five"],
+                        [u"four", u"five", u"six"],
+                        [u"nine", u"ten"],
+                    ]
+                )
+            )
+            .order_by(arrtable.c.id)
+        )
+
+        eq_(
+            connection.execute(stmt).all(),
+            [
+                (["four", "five", "six"],),
+                (["one", "five"],),
+            ],
+        )
+
     def test_array_agg(self, metadata, connection):
         values_table = Table("values", metadata, Column("value", Integer))
         metadata.create_all(connection)
@@ -2151,6 +2258,9 @@ class _ArrayOfEnum(TypeDecorator):
     impl = postgresql.ARRAY
     cache_ok = True
 
+    # note expanding logic is checking _is_array here so that has to
+    # translate through the TypeDecorator
+
     def bind_expression(self, bindvalue):
         return sa.cast(bindvalue, self)
 
@@ -2207,56 +2317,93 @@ class ArrayEnum(fixtures.TestBase):
             connection,
         )
 
-    @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
-    @testing.combinations(
-        sqltypes.ARRAY,
-        postgresql.ARRAY,
-        (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
-        argnames="array_cls",
-    )
-    def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
-        tbl = Table(
-            "enum_table",
-            self.metadata,
-            Column("id", Integer, primary_key=True),
-            Column(
-                "enum_col",
-                array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
-            ),
-        )
-
-        if util.py3k:
-            from enum import Enum
-
-            class MyEnum(Enum):
-                a = "aaa"
-                b = "bbb"
-                c = "ccc"
-
-            tbl.append_column(
+    @testing.fixture
+    def array_of_enum_fixture(self, metadata, connection):
+        def go(array_cls, enum_cls):
+            tbl = Table(
+                "enum_table",
+                metadata,
+                Column("id", Integer, primary_key=True),
                 Column(
-                    "pyenum_col",
-                    array_cls(enum_cls(MyEnum)),
+                    "enum_col",
+                    array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
                 ),
             )
+            if util.py3k:
+                from enum import Enum
+
+                class MyEnum(Enum):
+                    a = "aaa"
+                    b = "bbb"
+                    c = "ccc"
+
+                tbl.append_column(
+                    Column(
+                        "pyenum_col",
+                        array_cls(enum_cls(MyEnum)),
+                    ),
+                )
+            else:
+                MyEnum = None
 
-        self.metadata.create_all(connection)
+            metadata.create_all(connection)
+            connection.execute(
+                tbl.insert(),
+                [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+            )
+            return tbl, MyEnum
 
-        connection.execute(
-            tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}]
+        yield go
+
+    def _enum_combinations(fn):
+        return testing.combinations(
+            sqltypes.Enum, postgresql.ENUM, argnames="enum_cls"
+        )(
+            testing.combinations(
+                sqltypes.ARRAY,
+                postgresql.ARRAY,
+                (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+                argnames="array_cls",
+            )(fn)
         )
 
+    @_enum_combinations
+    def test_array_of_enums_roundtrip(
+        self, array_of_enum_fixture, connection, array_cls, enum_cls
+    ):
+        tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+        # test select back
         sel = select(tbl.c.enum_col).order_by(tbl.c.id)
         eq_(
             connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)]
         )
 
-        if util.py3k:
-            connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
-            sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
-            eq_(connection.scalar(sel), [MyEnum.a])
+    @_enum_combinations
+    def test_array_of_enums_expanding_in(
+        self, array_of_enum_fixture, connection, array_cls, enum_cls
+    ):
+        tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+        # test select with WHERE using expanding IN against arrays
+        # #7177
+        sel = (
+            select(tbl.c.enum_col)
+            .where(tbl.c.enum_col.in_([["foo", "bar"], ["bar", "foo"]]))
+            .order_by(tbl.c.id)
+        )
+        eq_(connection.execute(sel).fetchall(), [(["foo", "bar"],)])
+
+    @_enum_combinations
+    @testing.requires.python3
+    def test_array_of_enums_native_roundtrip(
+        self, array_of_enum_fixture, connection, array_cls, enum_cls
+    ):
+        tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
 
-        self.metadata.drop_all(connection)
+        connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
+        sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
+        eq_(connection.scalar(sel), [MyEnum.a])
 
 
 class ArrayJSON(fixtures.TestBase):
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 3d1b4fe85ec1b889bafcc574744e71cb45b2b782..0d43448d5ed3c6e19e36cbaa3ca43d737361ed58 100644 (file)
@@ -188,7 +188,10 @@ class TraversalTest(
         ("clone",), ("pickle",), ("conv_to_unique"), ("none"), argnames="meth"
     )
     @testing.combinations(
-        ("name with space",), ("name with [brackets]",), argnames="name"
+        ("name with space",),
+        ("name with [brackets]",),
+        ("name with~~tildes~~",),
+        argnames="name",
     )
     def test_bindparam_key_proc_for_copies(self, meth, name):
         r"""test :ticket:`6249`.
@@ -199,7 +202,7 @@ class TraversalTest(
 
         Currently, the bind key reg is::
 
-            re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+            re.sub(r"[%\(\) \$\[\]]", "_", name)
 
         and the compiler postcompile reg is::
 
@@ -218,7 +221,8 @@ class TraversalTest(
             expr.right.unique = False
             expr.right._convert_to_unique()
 
-        token = re.sub(r"[%\(\) \$]+", "_", name).strip("_")
+        token = re.sub(r"[%\(\) \$\[\]]", "_", name)
+
         self.assert_compile(
             expr,
             '"%(name)s" IN (:%(token)s_1_1, '
diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py
index 51ee0ae62905f0e37ec6aaed9fa421a48bc9a2db..adcaef39cb4fbb71ec27f817c9265513031d0bd4 100644 (file)
@@ -182,6 +182,29 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "test_table WHERE test_table.y = lower(:y_1)",
         )
 
+    def test_in_binds(self):
+        table = self._fixture()
+
+        self.assert_compile(
+            select(table).where(
+                table.c.y.in_(["hi", "there", "some", "expr"])
+            ),
+            "SELECT test_table.x, lower(test_table.y) AS y FROM "
+            "test_table WHERE test_table.y IN "
+            "([POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
+            render_postcompile=False,
+        )
+
+        self.assert_compile(
+            select(table).where(
+                table.c.y.in_(["hi", "there", "some", "expr"])
+            ),
+            "SELECT test_table.x, lower(test_table.y) AS y FROM "
+            "test_table WHERE test_table.y IN "
+            "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
+            render_postcompile=True,
+        )
+
     def test_dialect(self):
         table = self._fixture()
         dialect = self._dialect_level_fixture()
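
The fixture behind these assertions applies lower() to the bound value
via :meth:`_types.TypeEngine.bind_expression`; a comparable user-defined
type can be sketched as follows (the LowerCaseString name is
hypothetical, and the compiled output shown is for the default
dialect)::

    from sqlalchemy import String, TypeDecorator, column, func

    class LowerCaseString(TypeDecorator):
        """Hypothetical type that applies lower() on the SQL side."""

        impl = String
        cache_ok = True

        def bind_expression(self, bindvalue):
            return func.lower(bindvalue)

    expr = column("y", LowerCaseString()).in_(["hi", "there"])
    print(expr.compile(compile_kwargs={"render_postcompile": True}))
    # expected, roughly: y IN (lower(:y_1_1), lower(:y_1_2))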