--- /dev/null
+.. change::
+ :tags: sql, bug, regression
+ :tickets: 7177
+
+ Fixed issue where "expanding IN" would fail to function correctly with
+ datatypes that use the :meth:`_types.TypeEngine.bind_expression` method,
+ where the method would need to be applied to each element of the
+ IN expression rather than the overall IN expression itself.
+
+.. change::
+ :tags: postgresql, bug, regression
+ :tickets: 7177
+
+ Fixed issue where IN expressions against a series of array elements, as can
+ be done with PostgreSQL, would fail to function correctly due to multiple
+ issues within the "expanding IN" feature of SQLAlchemy Core that was
+ standardized in version 1.4. The psycopg2 dialect now makes use of the
+ :meth:`_types.TypeEngine.bind_expression` method with :class:`_types.ARRAY`
+ to portably apply the correct casts to elements. The asyncpg dialect was
+ not affected by this issue as it applies bind-level casts at the driver
+ level rather than at the compiler level.
+
if not self._inputsizes:
return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
else:
-
return tuple(
"$%d::%s" % (idx, typ) if typ else "$%d" % idx
for idx, typ in enumerate(
self.drop(bind=bind, checkfirst=checkfirst)
+class _ColonCast(elements.Cast):
+ __visit_name__ = "colon_cast"
+
+ def __init__(self, expression, type_):
+ self.type = type_
+ self.clause = expression
+ self.typeclause = elements.TypeClause(type_)
+
+
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
class PGCompiler(compiler.SQLCompiler):
+ def visit_colon_cast(self, element, **kw):
+ return "%s::%s" % (
+ element.clause._compiler_dispatch(self, **kw),
+ element.typeclause._compiler_dispatch(self, **kw),
+ )
+
def visit_array(self, element, **kw):
return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
import re
from uuid import UUID as _python_UUID
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
-from ...sql import elements
from ...util import collections_abc
return super(_PGHStore, self).result_processor(dialect, coltype)
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
return None
class PGCompiler_psycopg2(PGCompiler):
- def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
-
- text = super(PGCompiler_psycopg2, self).visit_bindparam(
- bindparam, skip_bind_expression=skip_bind_expression, **kw
- )
- # note that if the type has a bind_expression(), we will get a
- # double compile here
- if not skip_bind_expression and (
- bindparam.type._is_array or bindparam.type._is_type_decorator
- ):
- typ = bindparam.type._unwrapped_dialect_impl(self.dialect)
-
- if typ._is_array:
- text += "::%s" % (
- elements.TypeClause(typ)._compiler_dispatch(
- self, skip_bind_expression=skip_bind_expression, **kw
- ),
- )
- return text
+ pass
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID,
+ sqltypes.ARRAY: _PGARRAY,
},
)
from the bind parameter's ``TypeEngine`` objects.
        This method is only called by those dialects which require it,
- currently cx_oracle.
+ currently cx_oracle, asyncpg and pg8000.
"""
if self.isddl or self.is_text:
"named": ":%(name)s",
}
-BIND_TRANSLATE = {
- "pyformat": re.compile(r"[%\(\)]"),
- "named": re.compile(r"[\:]"),
-}
-_BIND_TRANSLATE_CHARS = {"%": "P", "(": "A", ")": "Z", ":": "C"}
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
OPERATORS = {
# binary
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
- self._bind_translate = BIND_TRANSLATE.get(dialect.paramstyle, None)
self.ctes = None
N as a bound parameter.
"""
-
if parameters is None:
parameters = self.construct_params()
replacement_expressions = {}
to_update_sets = {}
+ # notes:
+ # *unescaped* parameter names in:
+ # self.bind_names, self.binds, self._bind_processors
+ #
+ # *escaped* parameter names in:
+ # construct_params(), replacement_expressions
+
for name in (
self.positiontup if self.positional else self.bind_names.values()
):
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if self.escaped_bind_names
+ else name
+ )
parameter = self.binds[name]
if parameter in self.literal_execute_params:
- if name not in replacement_expressions:
- value = parameters.pop(name)
+ if escaped_name not in replacement_expressions:
+ value = parameters.pop(escaped_name)
- replacement_expressions[name] = self.render_literal_bindparam(
+ replacement_expressions[
+ escaped_name
+ ] = self.render_literal_bindparam(
parameter, render_literal_value=value
)
continue
if parameter in self.post_compile_params:
- if name in replacement_expressions:
- to_update = to_update_sets[name]
+ if escaped_name in replacement_expressions:
+ to_update = to_update_sets[escaped_name]
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
# process it. the single name is being replaced with
# individual numbered parameters for each value in the
# param.
- values = parameters.pop(name)
+ values = parameters.pop(escaped_name)
leep = self._literal_execute_expanding_parameter
- to_update, replacement_expr = leep(name, parameter, values)
+ to_update, replacement_expr = leep(
+ escaped_name, parameter, values
+ )
- to_update_sets[name] = to_update
- replacement_expressions[name] = replacement_expr
+ to_update_sets[escaped_name] = to_update
+ replacement_expressions[escaped_name] = replacement_expr
if not parameter.literal_execute:
parameters.update(to_update)
positiontup.append(name)
def process_expanding(m):
- return replacement_expressions[m.group(1)]
+ key = m.group(1)
+ expr = replacement_expressions[key]
+
+ # if POSTCOMPILE included a bind_expression, render that
+ # around each element
+ if m.group(2):
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ expr = ", ".join(
+ "%s%s%s" % (be_left, exp, be_right)
+ for exp in expr.split(", ")
+ )
+ return expr
statement = re.sub(
- r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string
+ r"\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ process_expanding,
+ self.string,
)
expanded_state = ExpandedState(
self, parameter, values
):
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + self.visit_empty_set_op_expr(
)
elif isinstance(values[0], (tuple, list)):
- assert parameter.type._is_tuple_type
+ assert typ_dialect_impl._is_tuple_type
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
for i, tuple_element in enumerate(values)
)
else:
- assert not parameter.type._is_tuple_type
+ assert not typ_dialect_impl._is_tuple_type
replacement_expression = ", ".join(
self.render_literal_value(value, parameter.type)
for value in values
parameter, values
)
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
if not values:
to_update = []
- if parameter.type._is_tuple_type:
+ if typ_dialect_impl._is_tuple_type:
replacement_expression = self.visit_empty_set_op_expr(
parameter.type.types, parameter.expand_op
[parameter.type], parameter.expand_op
)
- elif isinstance(values[0], (tuple, list)):
+ elif (
+ isinstance(values[0], (tuple, list))
+ and not typ_dialect_impl._is_array
+ ):
to_update = [
("%s_%s_%s" % (name, i, j), value)
for i, tuple_element in enumerate(values, 1)
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
- return self.process(
+ wrapped = self.process(
bind_expression,
skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
literal_execute=literal_execute,
+ render_postcompile=render_postcompile,
**kwargs
)
+ if bindparam.expanding:
+ # for postcompile w/ expanding, move the "wrapped" part
+ # of this into the inside
+ m = re.match(
+ r"^(.*)\(\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+ )
+ wrapped = "([POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+ m.group(2),
+ m.group(1),
+ m.group(3),
+ )
+ return wrapped
if not literal_binds:
literal_execute = (
positional_names.append(name)
else:
self.positiontup.append(name)
- elif not post_compile and not escaped_from:
- tr_reg = self._bind_translate
- if tr_reg.search(name):
- # i'd rather use translate() here but I can't get it to work
- # in all cases under Python 2, not worth it right now
- new_name = tr_reg.sub(
+ elif not escaped_from:
+
+ if _BIND_TRANSLATE_RE.search(name):
+ # not quite the translate use case as we want to
+ # also get a quick boolean if we even found
+ # unusual characters in the name
+ new_name = _BIND_TRANSLATE_RE.sub(
lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
name,
)
postgresql.ARRAY(Unicode(30), dimensions=3), "VARCHAR(30)[][][]"
)
+ def test_array_in_enum_psycopg2_cast(self):
+ expr = column(
+ "x",
+ postgresql.ARRAY(
+ postgresql.ENUM("one", "two", "three", name="myenum")
+ ),
+ ).in_([["one", "two"], ["three", "four"]])
+
+ self.assert_compile(
+ expr,
+ "x IN ([POSTCOMPILE_x_1~~~~REPL~~::myenum[]~~])",
+ dialect=postgresql.psycopg2.dialect(),
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN (%(x_1_1)s::myenum[], %(x_1_2)s::myenum[])",
+ dialect=postgresql.psycopg2.dialect(),
+ render_postcompile=True,
+ )
+
+ def test_array_in_str_psycopg2_cast(self):
+ expr = column("x", postgresql.ARRAY(String(15))).in_(
+ [["one", "two"], ["three", "four"]]
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN ([POSTCOMPILE_x_1~~~~REPL~~::VARCHAR(15)[]~~])",
+ dialect=postgresql.psycopg2.dialect(),
+ )
+
+ self.assert_compile(
+ expr,
+ "x IN (%(x_1_1)s::VARCHAR(15)[], %(x_1_2)s::VARCHAR(15)[])",
+ dialect=postgresql.psycopg2.dialect(),
+ render_postcompile=True,
+ )
+
def test_array_type_render_str_collate_multidim(self):
self.assert_compile(
postgresql.ARRAY(Unicode(30, collation="en_US"), dimensions=2),
t = Table(
"t",
metadata,
- Column("data", sqltypes.ARRAY(String(50, collation="en_US"))),
+ Column("data", self.ARRAY(String(50, collation="en_US"))),
)
t.create(connection)
+ @testing.fixture
+ def array_in_fixture(self, connection):
+ arrtable = self.tables.arrtable
+
+ connection.execute(
+ arrtable.insert(),
+ [
+ {
+ "id": 1,
+ "intarr": [1, 2, 3],
+ "strarr": [u"one", u"two", u"three"],
+ },
+ {
+ "id": 2,
+ "intarr": [4, 5, 6],
+ "strarr": [u"four", u"five", u"six"],
+ },
+ {"id": 3, "intarr": [1, 5], "strarr": [u"one", u"five"]},
+ {"id": 4, "intarr": [], "strarr": []},
+ ],
+ )
+
+ def test_array_in_int(self, array_in_fixture, connection):
+ """test #7177"""
+
+ arrtable = self.tables.arrtable
+
+ stmt = (
+ select(arrtable.c.intarr)
+ .where(arrtable.c.intarr.in_([[1, 5], [4, 5, 6], [9, 10]]))
+ .order_by(arrtable.c.id)
+ )
+
+ eq_(
+ connection.execute(stmt).all(),
+ [
+ ([4, 5, 6],),
+ ([1, 5],),
+ ],
+ )
+
+ def test_array_in_str(self, array_in_fixture, connection):
+ """test #7177"""
+
+ arrtable = self.tables.arrtable
+
+ stmt = (
+ select(arrtable.c.strarr)
+ .where(
+ arrtable.c.strarr.in_(
+ [
+ [u"one", u"five"],
+ [u"four", u"five", u"six"],
+ [u"nine", u"ten"],
+ ]
+ )
+ )
+ .order_by(arrtable.c.id)
+ )
+
+ eq_(
+ connection.execute(stmt).all(),
+ [
+ (["four", "five", "six"],),
+ (["one", "five"],),
+ ],
+ )
+
def test_array_agg(self, metadata, connection):
values_table = Table("values", metadata, Column("value", Integer))
metadata.create_all(connection)
impl = postgresql.ARRAY
cache_ok = True
+ # note expanding logic is checking _is_array here so that has to
+ # translate through the TypeDecorator
+
def bind_expression(self, bindvalue):
return sa.cast(bindvalue, self)
connection,
)
- @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.combinations(
- sqltypes.ARRAY,
- postgresql.ARRAY,
- (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
- argnames="array_cls",
- )
- def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
- tbl = Table(
- "enum_table",
- self.metadata,
- Column("id", Integer, primary_key=True),
- Column(
- "enum_col",
- array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
- ),
- )
-
- if util.py3k:
- from enum import Enum
-
- class MyEnum(Enum):
- a = "aaa"
- b = "bbb"
- c = "ccc"
-
- tbl.append_column(
+ @testing.fixture
+ def array_of_enum_fixture(self, metadata, connection):
+ def go(array_cls, enum_cls):
+ tbl = Table(
+ "enum_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
Column(
- "pyenum_col",
- array_cls(enum_cls(MyEnum)),
+ "enum_col",
+ array_cls(enum_cls("foo", "bar", "baz", name="an_enum")),
),
)
+ if util.py3k:
+ from enum import Enum
+
+ class MyEnum(Enum):
+ a = "aaa"
+ b = "bbb"
+ c = "ccc"
+
+ tbl.append_column(
+ Column(
+ "pyenum_col",
+ array_cls(enum_cls(MyEnum)),
+ ),
+ )
+ else:
+ MyEnum = None
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
+ connection.execute(
+ tbl.insert(),
+ [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
+ )
+ return tbl, MyEnum
- connection.execute(
- tbl.insert(), [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}]
+ yield go
+
+ def _enum_combinations(fn):
+ return testing.combinations(
+ sqltypes.Enum, postgresql.ENUM, argnames="enum_cls"
+ )(
+ testing.combinations(
+ sqltypes.ARRAY,
+ postgresql.ARRAY,
+ (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+ argnames="array_cls",
+ )(fn)
)
+ @_enum_combinations
+ def test_array_of_enums_roundtrip(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+ # test select back
sel = select(tbl.c.enum_col).order_by(tbl.c.id)
eq_(
connection.execute(sel).fetchall(), [(["foo"],), (["foo", "bar"],)]
)
- if util.py3k:
- connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
- sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
- eq_(connection.scalar(sel), [MyEnum.a])
+ @_enum_combinations
+ def test_array_of_enums_expanding_in(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+ # test select with WHERE using expanding IN against arrays
+ # #7177
+ sel = (
+ select(tbl.c.enum_col)
+ .where(tbl.c.enum_col.in_([["foo", "bar"], ["bar", "foo"]]))
+ .order_by(tbl.c.id)
+ )
+ eq_(connection.execute(sel).fetchall(), [(["foo", "bar"],)])
+
+ @_enum_combinations
+ @testing.requires.python3
+ def test_array_of_enums_native_roundtrip(
+ self, array_of_enum_fixture, connection, array_cls, enum_cls
+ ):
+ tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
- self.metadata.drop_all(connection)
+ connection.execute(tbl.insert(), {"pyenum_col": [MyEnum.a]})
+ sel = select(tbl.c.pyenum_col).order_by(tbl.c.id.desc())
+ eq_(connection.scalar(sel), [MyEnum.a])
class ArrayJSON(fixtures.TestBase):
("clone",), ("pickle",), ("conv_to_unique"), ("none"), argnames="meth"
)
@testing.combinations(
- ("name with space",), ("name with [brackets]",), argnames="name"
+ ("name with space",),
+ ("name with [brackets]",),
+ ("name with~~tildes~~",),
+ argnames="name",
)
def test_bindparam_key_proc_for_copies(self, meth, name):
r"""test :ticket:`6249`.
Currently, the bind key reg is::
- re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+ re.sub(r"[%\(\) \$\[\]]", "_", name)
and the compiler postcompile reg is::
expr.right.unique = False
expr.right._convert_to_unique()
- token = re.sub(r"[%\(\) \$]+", "_", name).strip("_")
+ token = re.sub(r"[%\(\) \$\[\]]", "_", name)
+
self.assert_compile(
expr,
'"%(name)s" IN (:%(token)s_1_1, '
"test_table WHERE test_table.y = lower(:y_1)",
)
+ def test_in_binds(self):
+ table = self._fixture()
+
+ self.assert_compile(
+ select(table).where(
+ table.c.y.in_(["hi", "there", "some", "expr"])
+ ),
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "([POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
+ render_postcompile=False,
+ )
+
+ self.assert_compile(
+ select(table).where(
+ table.c.y.in_(["hi", "there", "some", "expr"])
+ ),
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
+ render_postcompile=True,
+ )
+
def test_dialect(self):
table = self._fixture()
dialect = self._dialect_level_fixture()