From: Federico Caselli Date: Wed, 10 Mar 2021 22:54:52 +0000 (+0100) Subject: CAST the elements in ARRAYs when using psycopg2 X-Git-Tag: rel_1_3_24~6 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=65da8070e273ee2ea8ea71ba208d8645633b3fa0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git CAST the elements in ARRAYs when using psycopg2 Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for bound parameters that contain ARRAY elements. This allows the full range of datatypes to function correctly within arrays. The asyncpg dialect already generated these internal casts in the final statement. This also includes support for array slice updates as well as the PostgreSQL-specific :meth:`_postgresql.ARRAY.contains` method. Fixes: #6023 Change-Id: Ia7519ac4371a635f05ac69a3a4d0f4e6d2f04cad (cherry picked from commit dfa1d3b28f1a0abf1e11c76a94f7a65bf98d29af) --- diff --git a/doc/build/changelog/unreleased_13/6023.rst b/doc/build/changelog/unreleased_13/6023.rst new file mode 100644 index 0000000000..2cfe885678 --- /dev/null +++ b/doc/build/changelog/unreleased_13/6023.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, types, postgresql + :tickets: 6023 + + Adjusted the psycopg2 dialect to emit an explicit PostgreSQL-style cast for + bound parameters that contain ARRAY elements. This allows the full range of + datatypes to function correctly within arrays. The asyncpg dialect already + generated these internal casts in the final statement. This also includes + support for array slice updates as well as the PostgreSQL-specific + :meth:`_postgresql.ARRAY.contains` method. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 07bfbbbcea..4ad3b21345 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -313,12 +313,6 @@ class ARRAY(sqltypes.ARRAY): for x in arr ) - @util.memoized_property - def _require_cast(self): - return self._against_native_enum or isinstance( - self.item_type, sqltypes.JSON - ) - @util.memoized_property def _against_native_enum(self): return ( @@ -327,10 +321,7 @@ class ARRAY(sqltypes.ARRAY): ) def bind_expression(self, bindvalue): - if self._require_cast: - return expression.cast(bindvalue, self) - else: - return bindvalue + return bindvalue def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 2c5a934db4..9b75e25dcc 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -499,6 +499,7 @@ from ... import processors from ... import types as sqltypes from ... import util from ...engine import result as _result +from ...sql import elements from ...util import collections_abc try: @@ -644,7 +645,20 @@ class PGExecutionContext_psycopg2(PGExecutionContext): class PGCompiler_psycopg2(PGCompiler): - pass + def visit_bindparam(self, bindparam, skip_bind_expression=False, **kw): + + text = super(PGCompiler_psycopg2, self).visit_bindparam( + bindparam, skip_bind_expression=skip_bind_expression, **kw + ) + # note that if the type has a bind_expression(), we will get a + # double compile here + if not skip_bind_expression and bindparam.type._is_array: + text += "::%s" % ( + elements.TypeClause(bindparam.type)._compiler_dispatch( + self, skip_bind_expression=skip_bind_expression, **kw + ), + ) + return text class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index efa0e5d43c..794801bd56 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -731,6 +731,12 @@ def _get_stmt_parameters_params( elements.BindParameter(None, v, type_=k.type), **kw ) else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + v = compiler.process(v.self_group(), **kw) values.append((k, v)) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3a195e1ec7..d47d47cfbb 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2552,6 +2552,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): __visit_name__ = "ARRAY" + _is_array = True + zero_indexes = False """If True, Python zero-based indexes should be interpreted as one-based on the SQL expression side.""" diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9f279a3e94..a29e222cbc 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -45,6 +45,7 @@ class TypeEngine(Visitable): _sqla_type = True _isnull = False + _is_array = False class Comparator(operators.ColumnOperators): """Base class for custom comparison operations defined at the diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 86e4a913da..998c31a514 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -85,6 +85,11 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) +def combinations_list(arg_iterable, **kw): + "As combination, but takes a single iterable" + return combinations(*arg_iterable, **kw) + + def fixture(*arg, **kw): return _fixture_functions.fixture(*arg, **kw) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index ad4ebb6565..a5d4968b41 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -362,7 +362,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): "i": lambda obj: obj, "r": repr, "s": str, - "n": operator.attrgetter("__name__"), + "n": lambda obj: obj.__name__ + if hasattr(obj, "__name__") + else type(obj).__name__, } def combinations(self, *arg_sets, **kw): diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index d4719797b1..4b96d32e5e 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -5,10 +5,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import sys + from . import config from . import exclusions from .. import event from .. import schema +from ..util import OrderedDict __all__ = ["Table", "Column"] @@ -114,3 +117,41 @@ def _truncate_name(dialect, name): ) else: return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index e71156f442..263849e784 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,5 +1,4 @@ # coding: utf-8 - from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import Column @@ -31,6 +30,8 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import TSRANGE +from sqlalchemy.dialects.postgresql.base import PGDialect +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.orm import aliased from sqlalchemy.orm import mapper from sqlalchemy.orm import Session @@ -1311,13 +1312,28 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( - c.contains([1]), "x @> %(x_1)s", checkparams={"x_1": [1]} + c.contains([1]), + "x @> %(x_1)s::INTEGER[]", + checkparams={"x_1": [1]}, + dialect=PGDialect_psycopg2(), + ) + self.assert_compile( + c.contained_by([2]), + "x <@ %(x_1)s::INTEGER[]", + checkparams={"x_1": [2]}, + dialect=PGDialect_psycopg2(), ) self.assert_compile( - c.contained_by([2]), "x <@ %(x_1)s", checkparams={"x_1": [2]} + c.contained_by([2]), + "x <@ %(x_1)s", + checkparams={"x_1": [2]}, + dialect=PGDialect(), ) self.assert_compile( - c.overlap([3]), "x && %(x_1)s", checkparams={"x_1": [3]} + c.overlap([3]), + "x && %(x_1)s::INTEGER[]", + checkparams={"x_1": [3]}, + dialect=PGDialect_psycopg2(), ) self.assert_compile( postgresql.Any(4, c), @@ -1365,7 +1381,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): checkparams={"param_1": 7}, ) - def _test_array_zero_indexes(self, zero_indexes): + @testing.combinations((True,), (False,)) + def test_array_zero_indexes(self, zero_indexes): c = Column("x", postgresql.ARRAY(Integer, zero_indexes=zero_indexes)) add_one = 1 if zero_indexes else 0 @@ -1403,12 +1420,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): }, ) - def test_array_zero_indexes_true(self): - self._test_array_zero_indexes(True) - - def test_array_zero_indexes_false(self): - self._test_array_zero_indexes(False) - def test_array_literal_type(self): isinstance(postgresql.array([1, 2]).type, postgresql.ARRAY) is_(postgresql.array([1, 2]).type.item_type._type_affinity, Integer) @@ -1536,6 +1547,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "%(param_2)s, %(param_3)s])", ) + def test_update_array(self): + m = MetaData() + t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) + self.assert_compile( + t.update().values({t.c.data: [1, 3, 4]}), + "UPDATE t SET data=%(data)s::INTEGER[]", + checkparams={"data": [1, 3, 4]}, + ) + def test_update_array_element(self): m = MetaData() t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) @@ -1548,10 +1568,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_update_array_slice(self): m = MetaData() t = Table("t", m, Column("data", postgresql.ARRAY(Integer))) + + # psycopg2-specific, has a cast + self.assert_compile( + t.update().values({t.c.data[2:5]: [2, 3, 4]}), + "UPDATE t SET data[%(data_1)s:%(data_2)s]=" + "%(param_1)s::INTEGER[]", + checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2}, + dialect=PGDialect_psycopg2(), + ) + + # default dialect does not, as DBAPIs may be doing this for us self.assert_compile( - t.update().values({t.c.data[2:5]: 2}), - "UPDATE t SET data[%(data_1)s:%(data_2)s]=%(param_1)s", - checkparams={"param_1": 2, "data_2": 5, "data_1": 2}, + t.update().values({t.c.data[2:5]: [2, 3, 4]}), + "UPDATE t SET data[%s:%s]=" "%s", + checkparams={"param_1": [2, 3, 4], "data_2": 5, "data_1": 2}, + dialect=PGDialect(paramstyle="format"), ) def test_from_only(self): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index d5921ece1d..a6658b3982 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -60,9 +60,15 @@ from sqlalchemy.testing.assertions import ComparesTables from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.assertions import is_ from sqlalchemy.testing.assertsql import RegexSQL +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.suite import test_types as suite from sqlalchemy.testing.util import round_decimal +try: + import enum +except ImportError: + enum = None + tztable = notztable = metadata = table = None @@ -1344,6 +1350,12 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): is_(expr.type.item_type.__class__, element_type) +AnEnum = pep435_enum("AnEnum") +AnEnum("Foo", 1) +AnEnum("Bar", 2) +AnEnum("Baz", 3) + + class ArrayRoundTripTest(object): __only_on__ = "postgresql" @@ -1743,6 +1755,258 @@ class ArrayRoundTripTest(object): t.drop(testing.db) eq_(inspect(testing.db).get_enums(), []) + def _type_combinations(exclude_json=False): + def str_values(x): + return ["one", "two: %s" % x, "three", "four", "five"] + + def unicode_values(x): + return [ + util.u("réveillé"), + util.u("drôle"), + util.u("S’il %s" % x), + util.u("🐍 %s" % x), + util.u("« S’il vous"), + ] + + def json_values(x): + return [ + 1, + {"a": x}, + {"b": [1, 2, 3]}, + ["d", "e", "f"], + {"struct": True, "none": None}, + ] + + def binary_values(x): + return [v.encode("utf-8") for v in unicode_values(x)] + + def enum_values(x): + return [ + AnEnum.Foo, + AnEnum.Baz, + AnEnum.get(x), + AnEnum.Baz, + AnEnum.Foo, + ] + + class inet_str(str): + def __eq__(self, other): + return str(self) == str(other) + + def __ne__(self, other): + return str(self) != str(other) + + class money_str(str): + def __eq__(self, other): + comp = re.sub(r"[^\d\.]", "", other) + return float(self) == float(comp) + + def __ne__(self, other): + return not self.__eq__(other) + + elements = [ + (sqltypes.Integer, lambda x: [1, x, 3, 4, 5]), + (sqltypes.Text, str_values), + (sqltypes.String, str_values), + (sqltypes.Unicode, unicode_values), + (postgresql.JSONB, json_values), + (sqltypes.Boolean, lambda x: [False] + [True] * x), + ( + sqltypes.LargeBinary, + binary_values, + ), + ( + postgresql.BYTEA, + binary_values, + ), + ( + postgresql.INET, + lambda x: [ + inet_str("1.1.1.1"), + inet_str("{0}.{0}.{0}.{0}".format(x)), + inet_str("192.168.1.1"), + inet_str("10.1.2.25"), + inet_str("192.168.22.5"), + ], + ), + ( + postgresql.CIDR, + lambda x: [ + inet_str("10.0.0.0/8"), + inet_str("%s.0.0.0/8" % x), + inet_str("192.168.1.0/24"), + inet_str("192.168.0.0/16"), + inet_str("192.168.1.25/32"), + ], + ), + ( + sqltypes.Date, + lambda x: [ + datetime.date(2020, 5, x), + datetime.date(2020, 7, 12), + datetime.date(2018, 12, 15), + datetime.date(2009, 1, 5), + datetime.date(2021, 3, 18), + ], + ), + ( + sqltypes.DateTime, + lambda x: [ + datetime.datetime(2020, 5, x, 2, 15, 0), + datetime.datetime(2020, 7, 12, 15, 30, x), + datetime.datetime(2018, 12, 15, 3, x, 25), + datetime.datetime(2009, 1, 5, 12, 45, x), + datetime.datetime(2021, 3, 18, 17, 1, 0), + ], + ), + ( + sqltypes.Numeric, + lambda x: [ + decimal.Decimal("45.10"), + decimal.Decimal(x), + decimal.Decimal(".03242"), + decimal.Decimal("532.3532"), + decimal.Decimal("95503.23"), + ], + ), + ( + postgresql.MONEY, + lambda x: [ + money_str("2"), + money_str("%s" % (5 + x)), + money_str("50.25"), + money_str("18.99"), + money_str("15.%s" % x), + ], + testing.skip_if( + "postgresql+psycopg2", "this is a psycopg2 bug" + ), + ), + ( + postgresql.HSTORE, + lambda x: [ + {"a": "1"}, + {"b": "%s" % x}, + {"c": "3"}, + {"c": "c2"}, + {"d": "e"}, + ], + testing.requires.hstore, + ), + (sqltypes.Enum(AnEnum, native_enum=True), enum_values), + ( + sqltypes.Enum( + AnEnum, native_enum=False, create_constraint=False + ), + enum_values, + ), + ] + + if not exclude_json: + elements.extend( + [ + (sqltypes.JSON, json_values), + (postgresql.JSON, json_values), + ] + ) + + return testing.combinations(*elements, argnames="type_,gen", id_="na") + + @classmethod + def _cls_type_combinations(cls, **kw): + return ArrayRoundTripTest.__dict__["_type_combinations"](**kw) + + @testing.fixture + def metadata(self): + m = MetaData() + yield m + + m.drop_all(testing.db) + + @testing.fixture + def type_specific_fixture(self, metadata, connection, type_): + meta = MetaData() + table = Table( + "foo", + meta, + Column("id", Integer), + Column("bar", self.ARRAY(type_)), + ) + + meta.create_all(connection) + + def go(gen): + connection.execute( + table.insert(), + [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}], + ) + return table + + return go + + @_type_combinations() + def test_type_specific_value_select( + self, type_specific_fixture, connection, type_, gen + ): + table = type_specific_fixture(gen) + + rows = connection.execute( + select([table.c.bar]).order_by(table.c.id) + ).fetchall() + + eq_(rows, [(gen(1),), (gen(2),)]) + + @_type_combinations() + def test_type_specific_value_update( + self, type_specific_fixture, connection, type_, gen + ): + table = type_specific_fixture(gen) + + new_gen = gen(3) + connection.execute( + table.update().where(table.c.id == 2).values(bar=new_gen) + ) + + eq_( + new_gen, + connection.scalar(select([table.c.bar]).where(table.c.id == 2)), + ) + + @_type_combinations() + def test_type_specific_slice_update( + self, type_specific_fixture, connection, type_, gen + ): + table = type_specific_fixture(gen) + + new_gen = gen(3) + + connection.execute( + table.update() + .where(table.c.id == 2) + .values({table.c.bar[1:3]: new_gen[1:4]}) + ) + + rows = connection.execute( + select([table.c.bar]).order_by(table.c.id) + ).fetchall() + + sliced_gen = gen(2) + sliced_gen[0:3] = new_gen[1:4] + + eq_(rows, [(gen(1),), (sliced_gen,)]) + + @_type_combinations(exclude_json=True) + def test_type_specific_value_delete( + self, type_specific_fixture, connection, type_, gen + ): + table = type_specific_fixture(gen) + + new_gen = gen(2) + + connection.execute(table.delete().where(table.c.bar == new_gen)) + + eq_(connection.scalar(select([func.count(table.c.id)])), 1) + class CoreArrayRoundTripTest( ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults @@ -1756,6 +2020,23 @@ class PGArrayRoundTripTest( ): ARRAY = postgresql.ARRAY + @ArrayRoundTripTest._cls_type_combinations(exclude_json=True) + def test_type_specific_contains( + self, type_specific_fixture, connection, type_, gen + ): + table = type_specific_fixture(gen) + + connection.execute( + table.insert(), + [{"id": 1, "bar": gen(1)}, {"id": 2, "bar": gen(2)}], + ) + + id_, value = connection.execute( + select([table]).where(table.c.bar.contains(gen(1))) + ).first() + eq_(id_, 1) + eq_(value, gen(1)) + @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) def test_undim_array_contains_typed_exec(self, struct): arrtable = self.tables.arrtable diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 55e9b0f9a5..6782f262b6 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -81,10 +81,10 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock from sqlalchemy.testing import pickleable from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers from sqlalchemy.testing.util import round_decimal -from sqlalchemy.util import OrderedDict def _all_dialect_modules(): @@ -1445,21 +1445,7 @@ class UnicodeTest(fixtures.TestBase): class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): __backend__ = True - class SomeEnum(object): - # Implements PEP 435 in the minimal fashion needed by SQLAlchemy - __members__ = OrderedDict() - - def __init__(self, name, value, alias=None): - self.name = name - self.value = value - self.__members__[name] = self - setattr(self.__class__, name, self) - if alias: - self.__members__[alias] = self - setattr(self.__class__, alias, self) - - class SomeOtherEnum(SomeEnum): - __members__ = OrderedDict() + SomeEnum = pep435_enum("SomeEnum") one = SomeEnum("one", 1) two = SomeEnum("two", 2) @@ -1467,6 +1453,8 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): a_member = SomeEnum("AMember", "a") b_member = SomeEnum("BMember", "b") + SomeOtherEnum = pep435_enum("SomeOtherEnum") + other_one = SomeOtherEnum("one", 1) other_two = SomeOtherEnum("two", 2) other_three = SomeOtherEnum("three", 3)