]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added PGArray datatype for using postgres array datatypes
authorAnts Aasma <ants.aasma@gmail.com>
Thu, 12 Jul 2007 00:31:32 +0000 (00:31 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Thu, 12 Jul 2007 00:31:32 +0000 (00:31 +0000)
CHANGES
lib/sqlalchemy/databases/postgres.py
test/dialect/postgres.py

diff --git a/CHANGES b/CHANGES
index c658896d5e7035c6d66cdb5fc52c1e157c92eb08..ce93a6d6d958e1dc0851225296d44030bc4175d1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     SelectResultsExt still exist but just return a slightly modified
     Query object for backwards-compatibility.  join_to() method 
     from SelectResults isn't present anymore, need to use join(). 
-    
+- postgres
+  - Added PGArray datatype for using postgres array datatypes
 0.3.9
 - general
     - better error message for NoSuchColumnError [ticket:607]
index e9798967e5c0d703e5e5a59cad891f29185b704f..4fd00198b5a52894ee8355eeb674047a97db99e8 100644 (file)
@@ -83,6 +83,39 @@ class PGBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
 
+class PGArray(sqltypes.TypeEngine):
+    def __init__(self, item_type):
+        if isinstance(item_type, type):
+            item_type = item_type()
+        self.item_type = item_type
+        
+    def dialect_impl(self, dialect):
+        impl = self.__class__.__new__(self.__class__)
+        impl.__dict__.update(self.__dict__)
+        impl.item_type = self.item_type.dialect_impl(dialect)
+        return impl
+    def convert_bind_param(self, value, dialect):
+        if value is None:
+            return value
+        def convert_item(item):
+            if isinstance(item, (list,tuple)):
+                return [convert_item(child) for child in item]
+            else:
+                return self.item_type.convert_bind_param(item, dialect)
+        return [convert_item(item) for item in value]
+    def convert_result_value(self, value, dialect):
+        if value is None:
+            return value
+        def convert_item(item):
+            if isinstance(item, list):
+                return [convert_item(child) for child in item]
+            else:
+                return self.item_type.convert_result_value(item, dialect)
+        # Could specialcase when item_type.convert_result_value is the default identity func
+        return [convert_item(item) for item in value]
+    def get_col_spec(self):
+        return self.item_type.get_col_spec() + '[]'
+
 colspecs = {
     sqltypes.Integer : PGInteger,
     sqltypes.Smallinteger : PGSmallInteger,
@@ -306,8 +339,9 @@ class PGDialect(ansisql.ANSIDialect):
             
             for name, format_type, default, notnull, attnum, table_oid in rows:
                 ## strip (30) from character varying(30)
-                attype = re.search('([^\(]+)', format_type).group(1)
+                attype = re.search('([^\([]+)', format_type).group(1)
                 nullable = not notnull
+                is_array = format_type.endswith('[]')
 
                 try:
                     charlen = re.search('\(([\d,]+)\)', format_type).group(1)
@@ -360,6 +394,8 @@ class PGDialect(ansisql.ANSIDialect):
 
                 if coltype:
                     coltype = coltype(*args, **kwargs)
+                    if is_array:
+                        coltype = PGArray(coltype)
                 else:
                     warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
                     coltype = sqltypes.NULLTYPE
index 686417d7a98fec9d92e04cbce1ef559b2c70d84f..7606b356845a2ffd4b12d45083be1bd2754aab80 100644 (file)
@@ -190,6 +190,48 @@ class TimezoneTest(AssertMixin):
         x = c.last_updated_params()
         print x['date'] == somedate
 
+class ArrayTest(AssertMixin):
+    @testbase.supported('postgres')
+    def setUpAll(self):
+        global metadata, arrtable
+        metadata = MetaData(testbase.db)
+        
+        arrtable = Table('arrtable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('intarr', postgres.PGArray(Integer)),
+            Column('strarr', postgres.PGArray(String), nullable=False)
+        )
+        metadata.create_all()
+    @testbase.supported('postgres')
+    def tearDownAll(self):
+        metadata.drop_all()
     
+    @testbase.supported('postgres')
+    def test_reflect_array_column(self):
+        metadata2 = MetaData(testbase.db)
+        tbl = Table('arrtable', metadata2, autoload=True)
+        self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray))
+        self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray))
+        self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
+        self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
+        
+    @testbase.supported('postgres')
+    def test_insert_array(self):
+        arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+        results = arrtable.select().execute().fetchall()
+        self.assertEquals(len(results), 1)
+        self.assertEquals(results[0]['intarr'], [1,2,3])
+        self.assertEquals(results[0]['strarr'], ['abc','def'])
+        arrtable.delete().execute()
+
+    @testbase.supported('postgres')
+    def test_array_where(self):
+        arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+        arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
+        results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
+        self.assertEquals(len(results), 1)
+        self.assertEquals(results[0]['intarr'], [1,2,3])
+        arrtable.delete().execute()
+
 if __name__ == "__main__":
     testbase.main()