]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Pass desired array type from pg.array_agg to functions.array_agg
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Aug 2018 15:13:54 +0000 (11:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Aug 2018 15:13:54 +0000 (11:13 -0400)
Fixed the :func:`.postgresql.array_agg` function, which is a slightly
altered version of the usual :func:`.functions.array_agg` function, to also
accept an incoming "type" argument without forcing an ARRAY around it,
essentially the same thing that was fixed for the generic function in 1.1
in :ticket:`4107`.

Fixes: #4324
Change-Id: I399a29f59c945a217cdd22c65ff0325edea8ea65

doc/build/changelog/unreleased_12/4324.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/sql/functions.py
test/dialect/postgresql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_12/4324.rst b/doc/build/changelog/unreleased_12/4324.rst
new file mode 100644 (file)
index 0000000..5f3e879
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 4324
+
+    Fixed the :func:`.postgresql.array_agg` function, which is a slightly
+    altered version of the usual :func:`.functions.array_agg` function, to also
+    accept an incoming "type" argument without forcing an ARRAY around it,
+    essentially the same thing that was fixed for the generic function in 1.1
+    in :ticket:`4107`.
index 20ed0fc8d625d5a50671c6867e46a56b532a789a..71fb3cc5bc5724f5d0f51d4edb1c440839409584 100644 (file)
@@ -209,10 +209,11 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
 def array_agg(*arg, **kw):
     """PostgreSQL-specific form of :class:`.array_agg`, ensures
     return type is :class:`.postgresql.ARRAY` and not
-    the plain :class:`.types.ARRAY`.
+    the plain :class:`.types.ARRAY`, unless an explicit ``type_``
+    is passed.
 
     .. versionadded:: 1.1
 
     """
-    kw['type_'] = ARRAY(functions._type_from_args(arg))
+    kw['_default_array_type'] = ARRAY
     return functions.func.array_agg(*arg, **kw)
index 27d030d4ff4dc3fc6fee1cf4f71855ede2d593a2..5cea7750a73f248b4a067b59bef2fc1148406d69 100644 (file)
@@ -793,13 +793,14 @@ class array_agg(GenericFunction):
     def __init__(self, *args, **kwargs):
         args = [_literal_as_binds(c) for c in args]
 
+        default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY)
         if 'type_' not in kwargs:
 
             type_from_args = _type_from_args(args)
             if isinstance(type_from_args, sqltypes.ARRAY):
                 kwargs['type_'] = type_from_args
             else:
-                kwargs['type_'] = sqltypes.ARRAY(type_from_args)
+                kwargs['type_'] = default_array_type(type_from_args)
         kwargs['_parsed_args'] = args
         super(array_agg, self).__init__(*args, **kwargs)
 
index 66764fcc21f5c1b72c5f6269da67a34a8d166c93..5c4d72c6703220a683701ffb9ff31e4b011557cc 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import testing
 from sqlalchemy import Sequence, Table, Column, Integer, update, String,\
     func, MetaData, Enum, Index, and_, delete, select, cast, text, \
     Text, null
+from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql import ExcludeConstraint, array
 from sqlalchemy import exc, schema
 from sqlalchemy.dialects import postgresql
@@ -16,6 +17,8 @@ from sqlalchemy.sql import table, column, operators, literal_column
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.util import u, OrderedDict
 from sqlalchemy.dialects.postgresql import aggregate_order_by, insert
+from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
+from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
 
 
 class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -1097,6 +1100,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS string_agg_1 FROM table1"
         )
 
+    def test_pg_array_agg_implicit_pg_array(self):
+
+        expr = pg_array_agg(column('data', Integer))
+        assert isinstance(expr.type, PG_ARRAY)
+        is_(expr.type.item_type._type_affinity, Integer)
+
+    def test_pg_array_agg_uses_base_array(self):
+
+        expr = pg_array_agg(column('data', sqltypes.ARRAY(Integer)))
+        assert isinstance(expr.type, sqltypes.ARRAY)
+        assert not isinstance(expr.type, PG_ARRAY)
+        is_(expr.type.item_type._type_affinity, Integer)
+
+    def test_pg_array_agg_uses_pg_array(self):
+
+        expr = pg_array_agg(column('data', PG_ARRAY(Integer)))
+        assert isinstance(expr.type, PG_ARRAY)
+        is_(expr.type.item_type._type_affinity, Integer)
+
+    def test_pg_array_agg_explicit_base_array(self):
+
+        expr = pg_array_agg(column(
+            'data', sqltypes.ARRAY(Integer)), type_=sqltypes.ARRAY(Integer))
+        assert isinstance(expr.type, sqltypes.ARRAY)
+        assert not isinstance(expr.type, PG_ARRAY)
+        is_(expr.type.item_type._type_affinity, Integer)
+
+    def test_pg_array_agg_explicit_pg_array(self):
+
+        expr = pg_array_agg(column(
+            'data', sqltypes.ARRAY(Integer)), type_=PG_ARRAY(Integer))
+        assert isinstance(expr.type, PG_ARRAY)
+        is_(expr.type.item_type._type_affinity, Integer)
+
     def test_aggregate_order_by_adapt(self):
         m = MetaData()
         table = Table('table1', m, Column('a', Integer), Column('b', Integer))