--- /dev/null
+.. change::
+ :tags: bug, types, postgresql
+ :tickets: 6023
+
+ Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for
+ bound parameters that contain ARRAY elements. This allows the full range of
+ datatypes to function correctly within arrays. The asyncpg dialect already
+ generated these internal casts in the final statement. This also includes
+ support for array slice updates as well as the PostgreSQL-specific
+ :meth:`_postgresql.ARRAY.contains` method.
\ No newline at end of file
for x in arr
)
- @util.memoized_property
- def _require_cast(self):
- return self._against_native_enum or isinstance(
- self.item_type, sqltypes.JSON
- )
-
@util.memoized_property
def _against_native_enum(self):
return (
)
def bind_expression(self, bindvalue):
- if self._require_cast:
- return expression.cast(bindvalue, self)
- else:
- return bindvalue
+ return bindvalue
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
from ... import types as sqltypes
from ... import util
from ...engine import result as _result
+from ...sql import elements
from ...util import collections_abc
try:
class PGCompiler_psycopg2(PGCompiler):
- pass
+ def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw):
+
+ text = super(PGCompiler_psycopg2, self).visit_bindparam(
+ bindparam, skip_bind_expression=skip_bind_expression, **kw
+ )
+ # note that if the type has a bind_expression(), we will get a
+ # double compile here
+ if not skip_bind_expression and bindparam.type._is_array:
+ text += "::%s" % (
+ elements.TypeClause(bindparam.type)._compiler_dispatch(
+ self, skip_bind_expression=skip_bind_expression, **kw
+ ),
+ )
+ return text
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
elements.BindParameter(None, v, type_=k.type), **kw
)
else:
+ if v._is_bind_parameter and v.type._isnull:
+ # either unique parameter, or other bound parameters that
+ # were passed in directly
+ # set type to that of the column unconditionally
+ v = v._with_binary_element_type(k.type)
+
v = compiler.process(v.self_group(), **kw)
values.append((k, v))
__visit_name__ = "ARRAY"
+ _is_array = True
+
zero_indexes = False
"""If True, Python zero-based indexes should be interpreted as one-based
on the SQL expression side."""
_sqla_type = True
_isnull = False
+ _is_array = False
class Comparator(operators.ColumnOperators):
"""Base class for custom comparison operations defined at the
return _fixture_functions.combinations(*comb, **kw)
+def combinations_list(arg_iterable, **kw):
+ "As combination, but takes a single iterable"
+ return combinations(*arg_iterable, **kw)
+
+
def fixture(*arg, **kw):
return _fixture_functions.fixture(*arg, **kw)
"i": lambda obj: obj,
"r": repr,
"s": str,
- "n": operator.attrgetter("__name__"),
+ "n": lambda obj: obj.__name__
+ if hasattr(obj, "__name__")
+ else type(obj).__name__,
}
def combinations(self, *arg_sets, **kw):
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import sys
+
from . import config
from . import exclusions
from .. import event
from .. import schema
+from ..util import OrderedDict
__all__ = ["Table", "Column"]
)
else:
return name
+
+
+def pep435_enum(name):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ value_to_member[value] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ value_to_member = {}
+
+ @classmethod
+ def get(cls, value):
+ return value_to_member[value]
+
+ someenum = type(
+ name,
+ (object,),
+ {"__members__": __members__, "__init__": __init__, "get": get},
+ )
+
+ # getframe() trick for pickling I don't understand courtesy
+ # Python namedtuple()
+ try:
+ module = sys._getframe(1).f_globals.get("__name__", "__main__")
+ except (AttributeError, ValueError):
+ pass
+ if module is not None:
+ someenum.__module__ = module
+
+ return someenum
# coding: utf-8
-
from sqlalchemy import and_
from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.dialects.postgresql.base import PGDialect
+from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.orm import aliased
from sqlalchemy.orm import mapper
from sqlalchemy.orm import Session
)
self.assert_compile(
- c.contains([1]), "x @> %(x_1)s", checkparams={"x_1": [1]}
+ c.contains([1]),
+ "x @> %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [1]},
+ dialect=PGDialect_psycopg2(),
+ )
+ self.assert_compile(
+ c.contained_by([2]),
+ "x <@ %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [2]},
+ dialect=PGDialect_psycopg2(),
)
self.assert_compile(
- c.contained_by([2]), "x <@ %(x_1)s", checkparams={"x_1": [2]}
+ c.contained_by([2]),
+ "x <@ %(x_1)s",
+ checkparams={"x_1": [2]},
+ dialect=PGDialect(),
)
self.assert_compile(
- c.overlap([3]), "x && %(x_1)s", checkparams={"x_1": [3]}
+ c.overlap([3]),
+ "x && %(x_1)s::INTEGER[]",
+ checkparams={"x_1": [3]},
+ dialect=PGDialect_psycopg2(),
)
self.assert_compile(
postgresql.Any(4, c),
checkparams={"param_1": 7},
)
- def _test_array_zero_indexes(self, zero_indexes):
+ @testing.combinations((True,), (False,))
+ def test_array_zero_indexes(self, zero_indexes):
c = Column("x", postgresql.ARRAY(Integer, zero_indexes=zero_indexes))
add_one = 1 if zero_indexes else 0
},
)
- def test_array_zero_indexes_true(self):
- self._test_array_zero_indexes(True)
-
- def test_array_zero_indexes_false(self):
- self._test_array_zero_indexes(False)
-
def test_array_literal_type(self):
isinstance(postgresql.array([1, 2]).type, postgresql.ARRAY)
is_(postgresql.array([1, 2]).type.item_type._type_affinity, Integer)
"%(param_2)s, %(param_3)s])",
)
+ def test_update_array(self):
+ m = MetaData()
+ t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
+ self.assert_compile(
+ t.update().values({t.c.data: [1, 3, 4]}),
+ "UPDATE t SET data=%(data)s::INTEGER[]",
+ checkparams={"data": [1, 3, 4]},
+ )
+
def test_update_array_element(self):
m = MetaData()
t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
def test_update_array_slice(self):
m = MetaData()
t = Table("t", m, Column("data", postgresql.ARRAY(Integer)))
+
+ # psycopg2-specific, has a cast
+ self.assert_compile(
+ t.update().values({t.c.data[2:5]: [2, 3, 4]}),
+ "UPDATE t SET data[%(data_1)s:%(data_2)s]="
+ "%(param_1)s::INTEGER[]",
+ checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2},
+ dialect=PGDialect_psycopg2(),
+ )
+
+ # default dialect does not, as DBAPIs may be doing this for us
self.assert_compile(
- t.update().values({t.c.data[2:5]: 2}),
- "UPDATE t SET data[%(data_1)s:%(data_2)s]=%(param_1)s",
- checkparams={"param_1": 2, "data_2": 5, "data_1": 2},
+ t.update().values({t.c.data[2:5]: [2, 3, 4]}),
+ "UPDATE t SET data[%s:%s]=" "%s",
+ checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2},
+ dialect=PGDialect(paramstyle="format"),
)
def test_from_only(self):
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.assertions import is_
from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.suite import test_types as suite
from sqlalchemy.testing.util import round_decimal
+try:
+ import enum
+except ImportError:
+ enum = None
+
tztable = notztable = metadata = table = None
is_(expr.type.item_type.__class__, element_type)
+AnEnum = pep435_enum("AnEnum")
+AnEnum("Foo", 1)
+AnEnum("Bar", 2)
+AnEnum("Baz", 3)
+
+
class ArrayRoundTripTest(object):
__only_on__ = "postgresql"
t.drop(testing.db)
eq_(inspect(testing.db).get_enums(), [])
+ def _type_combinations(exclude_json=False):
+ def str_values(x):
+ return ["one", "two: %s" % x, "three", "four", "five"]
+
+ def unicode_values(x):
+ return [
+ util.u("réveillé"),
+ util.u("drôle"),
+ util.u("S’il %s" % x),
+ util.u("🐍 %s" % x),
+ util.u("« S’il vous"),
+ ]
+
+ def json_values(x):
+ return [
+ 1,
+ {"a": x},
+ {"b": [1, 2, 3]},
+ ["d", "e", "f"],
+ {"struct": True, "none": None},
+ ]
+
+ def binary_values(x):
+ return [v.encode("utf-8") for v in unicode_values(x)]
+
+ def enum_values(x):
+ return [
+ AnEnum.Foo,
+ AnEnum.Baz,
+ AnEnum.get(x),
+ AnEnum.Baz,
+ AnEnum.Foo,
+ ]
+
+ class inet_str(str):
+ def __eq__(self, other):
+ return str(self) == str(other)
+
+ def __ne__(self, other):
+ return str(self) != str(other)
+
+ class money_str(str):
+ def __eq__(self, other):
+ comp = re.sub(r"[^\d\.]", "", other)
+ return float(self) == float(comp)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ elements = [
+ (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]),
+ (sqltypes.Text, str_values),
+ (sqltypes.String, str_values),
+ (sqltypes.Unicode, unicode_values),
+ (postgresql.JSONB, json_values),
+ (sqltypes.Boolean, lambda x: [False] + [True] * x),
+ (
+ sqltypes.LargeBinary,
+ binary_values,
+ ),
+ (
+ postgresql.BYTEA,
+ binary_values,
+ ),
+ (
+ postgresql.INET,
+ lambda x: [
+ inet_str("1.1.1.1"),
+ inet_str("{0}.{0}.{0}.{0}".format(x)),
+ inet_str("192.168.1.1"),
+ inet_str("10.1.2.25"),
+ inet_str("192.168.22.5"),
+ ],
+ ),
+ (
+ postgresql.CIDR,
+ lambda x: [
+ inet_str("10.0.0.0/8"),
+ inet_str("%s.0.0.0/8" % x),
+ inet_str("192.168.1.0/24"),
+ inet_str("192.168.0.0/16"),
+ inet_str("192.168.1.25/32"),
+ ],
+ ),
+ (
+ sqltypes.Date,
+ lambda x: [
+ datetime.date(2020, 5, x),
+ datetime.date(2020, 7, 12),
+ datetime.date(2018, 12, 15),
+ datetime.date(2009, 1, 5),
+ datetime.date(2021, 3, 18),
+ ],
+ ),
+ (
+ sqltypes.DateTime,
+ lambda x: [
+ datetime.datetime(2020, 5, x, 2, 15, 0),
+ datetime.datetime(2020, 7, 12, 15, 30, x),
+ datetime.datetime(2018, 12, 15, 3, x, 25),
+ datetime.datetime(2009, 1, 5, 12, 45, x),
+ datetime.datetime(2021, 3, 18, 17, 1, 0),
+ ],
+ ),
+ (
+ sqltypes.Numeric,
+ lambda x: [
+ decimal.Decimal("45.10"),
+ decimal.Decimal(x),
+ decimal.Decimal(".03242"),
+ decimal.Decimal("532.3532"),
+ decimal.Decimal("95503.23"),
+ ],
+ ),
+ (
+ postgresql.MONEY,
+ lambda x: [
+ money_str("2"),
+ money_str("%s" % (5 + x)),
+ money_str("50.25"),
+ money_str("18.99"),
+ money_str("15.%s" % x),
+ ],
+ testing.skip_if(
+ "postgresql+psycopg2", "this is a psycopg2 bug"
+ ),
+ ),
+ (
+ postgresql.HSTORE,
+ lambda x: [
+ {"a": "1"},
+ {"b": "%s" % x},
+ {"c": "3"},
+ {"c": "c2"},
+ {"d": "e"},
+ ],
+ testing.requires.hstore,
+ ),
+ (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
+ (
+ sqltypes.Enum(
+ AnEnum, native_enum=False, create_constraint=False
+ ),
+ enum_values,
+ ),
+ ]
+
+ if not exclude_json:
+ elements.extend(
+ [
+ (sqltypes.JSON, json_values),
+ (postgresql.JSON, json_values),
+ ]
+ )
+
+ return testing.combinations(*elements, argnames="type_,gen", id_="na")
+
+ @classmethod
+ def _cls_type_combinations(cls, **kw):
+ return ArrayRoundTripTest.__dict__["_type_combinations"](**kw)
+
+ @testing.fixture
+ def metadata(self):
+ m = MetaData()
+ yield m
+
+ m.drop_all(testing.db)
+
+ @testing.fixture
+ def type_specific_fixture(self, metadata, connection, type_):
+ meta = MetaData()
+ table = Table(
+ "foo",
+ meta,
+ Column("id", Integer),
+ Column("bar", self.ARRAY(type_)),
+ )
+
+ meta.create_all(connection)
+
+ def go(gen):
+ connection.execute(
+ table.insert(),
+ [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ )
+ return table
+
+ return go
+
+ @_type_combinations()
+ def test_type_specific_value_select(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ rows = connection.execute(
+ select([table.c.bar]).order_by(table.c.id)
+ ).fetchall()
+
+ eq_(rows, [(gen(1),), (gen(2),)])
+
+ @_type_combinations()
+ def test_type_specific_value_update(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(3)
+ connection.execute(
+ table.update().where(table.c.id == 2).values(bar=new_gen)
+ )
+
+ eq_(
+ new_gen,
+ connection.scalar(select([table.c.bar]).where(table.c.id == 2)),
+ )
+
+ @_type_combinations()
+ def test_type_specific_slice_update(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(3)
+
+ connection.execute(
+ table.update()
+ .where(table.c.id == 2)
+ .values({table.c.bar[1:3]: new_gen[1:4]})
+ )
+
+ rows = connection.execute(
+ select([table.c.bar]).order_by(table.c.id)
+ ).fetchall()
+
+ sliced_gen = gen(2)
+ sliced_gen[0:3] = new_gen[1:4]
+
+ eq_(rows, [(gen(1),), (sliced_gen,)])
+
+ @_type_combinations(exclude_json=True)
+ def test_type_specific_value_delete(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ new_gen = gen(2)
+
+ connection.execute(table.delete().where(table.c.bar == new_gen))
+
+ eq_(connection.scalar(select([func.count(table.c.id)])), 1)
+
class CoreArrayRoundTripTest(
ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults
):
ARRAY = postgresql.ARRAY
+ @ArrayRoundTripTest._cls_type_combinations(exclude_json=True)
+ def test_type_specific_contains(
+ self, type_specific_fixture, connection, type_, gen
+ ):
+ table = type_specific_fixture(gen)
+
+ connection.execute(
+ table.insert(),
+ [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}],
+ )
+
+ id_, value = connection.execute(
+ select([table]).where(table.c.bar.contains(gen(1)))
+ ).first()
+ eq_(id_, 1)
+ eq_(value, gen(1))
+
@testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
def test_undim_array_contains_typed_exec(self, struct):
arrtable = self.tables.arrtable
from sqlalchemy.testing import mock
from sqlalchemy.testing import pickleable
from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
from sqlalchemy.testing.util import round_decimal
-from sqlalchemy.util import OrderedDict
def _all_dialect_modules():
class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
__backend__ = True
- class SomeEnum(object):
- # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
- __members__ = OrderedDict()
-
- def __init__(self, name, value, alias=None):
- self.name = name
- self.value = value
- self.__members__[name] = self
- setattr(self.__class__, name, self)
- if alias:
- self.__members__[alias] = self
- setattr(self.__class__, alias, self)
-
- class SomeOtherEnum(SomeEnum):
- __members__ = OrderedDict()
+ SomeEnum = pep435_enum("SomeEnum")
one = SomeEnum("one", 1)
two = SomeEnum("two", 2)
a_member = SomeEnum("AMember", "a")
b_member = SomeEnum("BMember", "b")
+ SomeOtherEnum = pep435_enum("SomeOtherEnum")
+
other_one = SomeOtherEnum("one", 1)
other_two = SomeOtherEnum("two", 2)
other_three = SomeOtherEnum("three", 3)