]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] postgresql.ARRAY features an optional
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Apr 2012 20:44:53 +0000 (16:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Apr 2012 20:44:53 +0000 (16:44 -0400)
"dimension" argument, will assign a specific
number of dimensions to the array which will
render in DDL as ARRAY[][]..., also improves
performance of bind/result processing.
[ticket:2441]

CHANGES
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/test_postgresql.py

diff --git a/CHANGES b/CHANGES
index 113d54db65de99b46d9e3ac55956f543ea3dbbb8..65599b1390f0e2672294f565d806a965e51ee05f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -272,8 +272,14 @@ CHANGES
     with_lockmode("read_nowait").
     These emit "FOR SHARE" and "FOR SHARE NOWAIT",
     respectively.  Courtesy Diana Clarke 
-    [ticket:2445]
-    also in 0.7.7.
+    [ticket:2445] Also in 0.7.7.
+
+  - [feature] postgresql.ARRAY features an optional
+    "dimension" argument, will assign a specific
+    number of dimensions to the array which will
+    render in DDL as ARRAY[][]..., also improves
+    performance of bind/result processing.
+    [ticket:2441]
 
 - mysql
   - [bug] Fixed bug whereby column name inside 
index d47b9e7579de0ad14bd7087e352db9582d006982..c3ff73fa16146001f80b2ed45c29110f8b390477 100644 (file)
@@ -329,7 +329,7 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
     """
     __visit_name__ = 'ARRAY'
 
-    def __init__(self, item_type, as_tuple=False):
+    def __init__(self, item_type, as_tuple=False, dimensions=None):
         """Construct an ARRAY.
 
         E.g.::
@@ -349,6 +349,14 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
           as psycopg2 return lists by default. When tuples are
           returned, the results are hashable.
 
+        :param dimensions: if non-None, the ARRAY will assume a fixed
+         number of dimensions.  This will cause the DDL emitted for this
+         ARRAY to include the exact number of bracket clauses ``[]``,
+         and will also optimize the performance of the type overall. 
+         Note that PG arrays are always implicitly "non-dimensioned",
+         meaning they can store any number of dimensions no matter how
+         they were declared.
+
         """
         if isinstance(item_type, ARRAY):
             raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
@@ -357,58 +365,59 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
             item_type = item_type()
         self.item_type = item_type
         self.as_tuple = as_tuple
+        self.dimensions = dimensions
 
     def compare_values(self, x, y):
         return x == y
 
-    def bind_processor(self, dialect):
-        item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item_proc(item)
+    def _proc_array(self, arr, itemproc, dim, collection):
+        if dim == 1 or (
+                    dim is None and
+                    (not arr or not isinstance(arr[0], (list, tuple)))
+                ):
+            if itemproc:
+                return collection(itemproc(x) for x in arr)
+            else:
+                return collection(arr)
         else:
-            def convert_item(item):
-                if isinstance(item, (list, tuple)):
-                    return [convert_item(child) for child in item]
-                else:
-                    return item
+            return collection(
+                    self._proc_array(
+                            x, itemproc, 
+                            dim - 1 if dim is not None else None, 
+                            collection) 
+                    for x in arr
+                )
+
+    def bind_processor(self, dialect):
+        item_proc = self.item_type.\
+                        dialect_impl(dialect).\
+                        bind_processor(dialect)
         def process(value):
             if value is None:
                 return value
-            return [convert_item(item) for item in value]
+            else:
+                return self._proc_array(
+                            value, 
+                            item_proc, 
+                            self.dimensions, 
+                            list)
         return process
 
     def result_processor(self, dialect, coltype):
-        item_proc = self.item_type.dialect_impl(dialect).result_processor(dialect, coltype)
-        if item_proc:
-            def convert_item(item):
-                if isinstance(item, list):
-                    r = [convert_item(child) for child in item]
-                    if self.as_tuple:
-                        r = tuple(r)
-                    return r
-                else:
-                    return item_proc(item)
-        else:
-            def convert_item(item):
-                if isinstance(item, list):
-                    r = [convert_item(child) for child in item]
-                    if self.as_tuple:
-                        r = tuple(r)
-                    return r
-                else:
-                    return item
+        item_proc = self.item_type.\
+                        dialect_impl(dialect).\
+                        result_processor(dialect, coltype)
         def process(value):
             if value is None:
                 return value
-            r = [convert_item(item) for item in value]
-            if self.as_tuple:
-                r = tuple(r)
-            return r
+            else:
+                return self._proc_array(
+                            value, 
+                            item_proc, 
+                            self.dimensions, 
+                            tuple if self.as_tuple else list)
         return process
+
 PGArray = ARRAY
 
 class ENUM(sqltypes.Enum):
@@ -841,7 +850,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
         return "BYTEA"
 
     def visit_ARRAY(self, type_):
-        return self.process(type_.item_type) + '[]'
+        return self.process(type_.item_type) + ('[]' * (type_.dimensions 
+                                                if type_.dimensions 
+                                                is not None else 1))
 
 
 class PGIdentifierPreparer(compiler.IdentifierPreparer):
index 769f18ce9a6d0c2cfefb8ce4bb80ccc7a3b0cf19..94eb2fe6ca4c28aa1ea1a583db5340300e6b3b35 100644 (file)
@@ -472,10 +472,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     def test_generate_multiple(self):
         """Test that the same enum twice only generates once
         for the create_all() call, without using checkfirst.
-        
+
         A 'memo' collection held by the DDL runner
         now handles this.
-        
+
         """
         metadata = self.metadata
 
@@ -1920,10 +1920,32 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
     def setup_class(cls):
         global metadata, arrtable
         metadata = MetaData(testing.db)
-        arrtable = Table('arrtable', metadata, Column('id', Integer,
-                         primary_key=True), Column('intarr',
-                         postgresql.ARRAY(Integer)), Column('strarr',
-                         postgresql.ARRAY(Unicode()), nullable=False))
+
+        class ProcValue(TypeDecorator):
+            impl = postgresql.ARRAY(Integer, dimensions=2)
+
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
+                return [
+                    [x + 5 for x in v]
+                    for v in value
+                ]
+
+            def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
+                return [
+                    [x - 7 for x in v]
+                    for v in value
+                ]
+
+        arrtable = Table('arrtable', metadata, 
+                        Column('id', Integer, primary_key=True), 
+                        Column('intarr',postgresql.ARRAY(Integer)), 
+                         Column('strarr',postgresql.ARRAY(Unicode())),
+                        Column('dimarr', ProcValue)
+                    )
         metadata.create_all()
 
     def teardown(self):
@@ -1994,73 +2016,23 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
         eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
 
-    @testing.fails_on('postgresql+pg8000',
-                      'pg8000 has poor support for PG arrays')
-    @testing.fails_on('postgresql+zxjdbc',
-                      'zxjdbc has no support for PG arrays')
-    def test_array_mutability(self):
-
-        class Foo(object):
-            pass
-
-        footable = Table('foo', metadata, 
-                        Column('id', Integer,primary_key=True), 
-                        Column('intarr', 
-                            postgresql.ARRAY(Integer, mutable=True), 
-                            nullable=True))
-        mapper(Foo, footable)
-        metadata.create_all()
-        sess = create_session()
-        foo = Foo()
-        foo.id = 1
-        foo.intarr = [1, 2, 3]
-        sess.add(foo)
-        sess.flush()
-        sess.expunge_all()
-        foo = sess.query(Foo).get(1)
-        eq_(foo.intarr, [1, 2, 3])
-        foo.intarr.append(4)
-        sess.flush()
-        sess.expunge_all()
-        foo = sess.query(Foo).get(1)
-        eq_(foo.intarr, [1, 2, 3, 4])
-        foo.intarr = []
-        sess.flush()
-        sess.expunge_all()
-        eq_(foo.intarr, [])
-        foo.intarr = None
-        sess.flush()
-        sess.expunge_all()
-        eq_(foo.intarr, None)
-
-        # Errors in r4217:
-
-        foo = Foo()
-        foo.id = 2
-        sess.add(foo)
-        sess.flush()
-
     @testing.fails_on('+zxjdbc',
                       "Can't infer the SQL type to use for an instance "
                       "of org.python.core.PyList.")
     @testing.provide_metadata
     def test_tuple_flag(self):
         metadata = self.metadata
-        assert_raises_message(
-            exc.ArgumentError, 
-            "mutable must be set to False if as_tuple is True.",
-            postgresql.ARRAY, Integer, mutable=True, 
-                as_tuple=True)
 
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True),
-            Column('data', postgresql.ARRAY(String(5), as_tuple=True, mutable=False)),
-            Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True, mutable=False)),
+            Column('data', postgresql.ARRAY(String(5), as_tuple=True)),
+            Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True)),
         )
         metadata.create_all()
         testing.db.execute(t1.insert(), id=1, data=["1","2","3"], data2=[5.4, 5.6])
         testing.db.execute(t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0])
-        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], data2=[[5.4, 5.6], [1.0, 1.1]])
+        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], 
+                        data2=[[5.4, 5.6], [1.0, 1.1]])
 
         r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall()
         eq_(
@@ -2077,7 +2049,12 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
             set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))])
         )
 
-
+    def test_dimension(self):
+        testing.db.execute(arrtable.insert(), dimarr=[[1, 2, 3], [4,5, 6]])
+        eq_(
+            testing.db.scalar(select([arrtable.c.dimarr])),
+            [[-1, 0, 1], [2, 3, 4]]
+        )
 
 class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
     __only_on__ = 'postgresql'