]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added "as_tuple" flag to pg ARRAY type, returns results
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Oct 2010 20:17:17 +0000 (16:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Oct 2010 20:17:17 +0000 (16:17 -0400)
as tuples instead of lists to allow hashing.

CHANGES
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/test_postgresql.py

diff --git a/CHANGES b/CHANGES
index 63761d6234baa48a8d437d38cbaefb6b3fcc7bba..04c75ac5963178d18d3e772768d379bbe7f57c3b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -194,6 +194,10 @@ CHANGES
      boolean values for 'use_native_unicode'.  
      [ticket:1899]
 
+- postgresql
+   - Added "as_tuple" flag to ARRAY type, returns results
+     as tuples instead of lists to allow hashing.
+     
 - mssql
    - Fixed reflection bug which did not properly handle
      reflection of unknown types.  [ticket:1946]
index 89769b8c0a66f151394d01b47b544fa50658d936..03260cdb349fc5362f4afe54a5e5bcf1bc30a73f 100644 (file)
@@ -171,7 +171,7 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
     """
     __visit_name__ = 'ARRAY'
     
-    def __init__(self, item_type, mutable=True):
+    def __init__(self, item_type, mutable=True, as_tuple=False):
         """Construct an ARRAY.
 
         E.g.::
@@ -186,9 +186,14 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
           ``ARRAY(ARRAY(Integer))`` or such. The type mapping figures out on
           the fly
 
-        :param mutable: Defaults to True: specify whether lists passed to this
+        :param mutable=True: Specify whether lists passed to this
           class should be considered mutable. If so, generic copy operations
           (typically used by the ORM) will shallow-copy values.
+        
+        :param as_tuple=False: Specify whether return results should be converted
+          to tuples from lists.  DBAPIs such as psycopg2 return lists by default.
+          When tuples are returned, the results are hashable.   This flag can only
+          be set to ``True`` when ``mutable`` is set to ``False``. (new in 0.6.5)
           
         """
         if isinstance(item_type, ARRAY):
@@ -198,7 +203,12 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
             item_type = item_type()
         self.item_type = item_type
         self.mutable = mutable
-
+        if mutable and as_tuple:
+            raise exc.ArgumentError(
+                "mutable must be set to False if as_tuple is True."
+            )
+        self.as_tuple = as_tuple
+        
     def copy_value(self, value):
         if value is None:
             return None
@@ -224,7 +234,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
     def adapt(self, impltype):
         return impltype(
             self.item_type,
-            mutable=self.mutable
+            mutable=self.mutable,
+            as_tuple=self.as_tuple
         )
         
     def bind_processor(self, dialect):
@@ -252,19 +263,28 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
         if item_proc:
             def convert_item(item):
                 if isinstance(item, list):
-                    return [convert_item(child) for child in item]
+                    r = [convert_item(child) for child in item]
+                    if self.as_tuple:
+                        r = tuple(r)
+                    return r
                 else:
                     return item_proc(item)
         else:
             def convert_item(item):
                 if isinstance(item, list):
-                    return [convert_item(child) for child in item]
+                    r = [convert_item(child) for child in item]
+                    if self.as_tuple:
+                        r = tuple(r)
+                    return r
                 else:
                     return item
         def process(value):
             if value is None:
                 return value
-            return [convert_item(item) for item in value]
+            r = [convert_item(item) for item in value]
+            if self.as_tuple:
+                r = tuple(r)
+            return r
         return process
 PGArray = ARRAY
 
index 9ad46c189c5a7bc884988af4d98619f4dd230395..36aa7f2c6f973953795da5e0cb5bb7f0a1b8d257 100644 (file)
@@ -1492,8 +1492,8 @@ class ArrayTest(TestBase, AssertsExecutionResults):
         metadata = MetaData(testing.db)
         arrtable = Table('arrtable', metadata, Column('id', Integer,
                          primary_key=True), Column('intarr',
-                         postgresql.PGArray(Integer)), Column('strarr',
-                         postgresql.PGArray(Unicode()), nullable=False))
+                         postgresql.ARRAY(Integer)), Column('strarr',
+                         postgresql.ARRAY(Unicode()), nullable=False))
         metadata.create_all()
 
     def teardown(self):
@@ -1506,8 +1506,8 @@ class ArrayTest(TestBase, AssertsExecutionResults):
     def test_reflect_array_column(self):
         metadata2 = MetaData(testing.db)
         tbl = Table('arrtable', metadata2, autoload=True)
-        assert isinstance(tbl.c.intarr.type, postgresql.PGArray)
-        assert isinstance(tbl.c.strarr.type, postgresql.PGArray)
+        assert isinstance(tbl.c.intarr.type, postgresql.ARRAY)
+        assert isinstance(tbl.c.strarr.type, postgresql.ARRAY)
         assert isinstance(tbl.c.intarr.type.item_type, Integer)
         assert isinstance(tbl.c.strarr.type.item_type, String)
 
@@ -1575,7 +1575,7 @@ class ArrayTest(TestBase, AssertsExecutionResults):
 
         footable = Table('foo', metadata, Column('id', Integer,
                          primary_key=True), Column('intarr',
-                         postgresql.PGArray(Integer), nullable=True))
+                         postgresql.ARRAY(Integer), nullable=True))
         mapper(Foo, footable)
         metadata.create_all()
         sess = create_session()
@@ -1607,7 +1607,41 @@ class ArrayTest(TestBase, AssertsExecutionResults):
         foo.id = 2
         sess.add(foo)
         sess.flush()
-
+    
+    @testing.provide_metadata
+    def test_tuple_flag(self):
+        assert_raises_message(
+            exc.ArgumentError, 
+            "mutable must be set to False if as_tuple is True.",
+            postgresql.ARRAY, Integer, as_tuple=True)
+        
+        t1 = Table('t1', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', postgresql.ARRAY(String(5), as_tuple=True, mutable=False)),
+            Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True, mutable=False)),
+        )
+        metadata.create_all()
+        testing.db.execute(t1.insert(), id=1, data=["1","2","3"], data2=[5.4, 5.6])
+        testing.db.execute(t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0])
+        testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], data2=[[5.4, 5.6], [1.0, 1.1]])
+        
+        r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall()
+        eq_(
+            r, 
+            [
+                (1, ('1', '2', '3'), (5.4, 5.6)), 
+                (2, ('4', '5', '6'), (1.0,)), 
+                (3, (('4', '5'), ('6', '7')), ((5.4, 5.6), (1.0, 1.1)))
+            ]
+        )
+        # hashable
+        eq_(
+            set(row[1] for row in r),
+            set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))])
+        )
+        
+        
+        
 class TimestampTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgresql'