def get_col_spec(self):
return "BOOLEAN"
+class PGArray(sqltypes.TypeEngine):
+ def __init__(self, item_type):
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+
+ def dialect_impl(self, dialect):
+ impl = self.__class__.__new__(self.__class__)
+ impl.__dict__.update(self.__dict__)
+ impl.item_type = self.item_type.dialect_impl(dialect)
+ return impl
+ def convert_bind_param(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, (list,tuple)):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_bind_param(item, dialect)
+ return [convert_item(item) for item in value]
+ def convert_result_value(self, value, dialect):
+ if value is None:
+ return value
+ def convert_item(item):
+ if isinstance(item, list):
+ return [convert_item(child) for child in item]
+ else:
+ return self.item_type.convert_result_value(item, dialect)
+ # Could specialcase when item_type.convert_result_value is the default identity func
+ return [convert_item(item) for item in value]
+ def get_col_spec(self):
+ return self.item_type.get_col_spec() + '[]'
+
colspecs = {
sqltypes.Integer : PGInteger,
sqltypes.Smallinteger : PGSmallInteger,
for name, format_type, default, notnull, attnum, table_oid in rows:
## strip (30) from character varying(30)
- attype = re.search('([^\(]+)', format_type).group(1)
+ attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
+ is_array = format_type.endswith('[]')
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
if coltype:
coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = PGArray(coltype)
else:
warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name)))
coltype = sqltypes.NULLTYPE
x = c.last_updated_params()
print x['date'] == somedate
+class ArrayTest(AssertMixin):
+ @testbase.supported('postgres')
+ def setUpAll(self):
+ global metadata, arrtable
+ metadata = MetaData(testbase.db)
+
+ arrtable = Table('arrtable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('intarr', postgres.PGArray(Integer)),
+ Column('strarr', postgres.PGArray(String), nullable=False)
+ )
+ metadata.create_all()
+ @testbase.supported('postgres')
+ def tearDownAll(self):
+ metadata.drop_all()
+ @testbase.supported('postgres')
+ def test_reflect_array_column(self):
+ metadata2 = MetaData(testbase.db)
+ tbl = Table('arrtable', metadata2, autoload=True)
+ self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray))
+ self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray))
+ self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
+ self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
+
+ @testbase.supported('postgres')
+ def test_insert_array(self):
+ arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+ results = arrtable.select().execute().fetchall()
+ self.assertEquals(len(results), 1)
+ self.assertEquals(results[0]['intarr'], [1,2,3])
+ self.assertEquals(results[0]['strarr'], ['abc','def'])
+ arrtable.delete().execute()
+
+ @testbase.supported('postgres')
+ def test_array_where(self):
+ arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
+ arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
+ results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
+ self.assertEquals(len(results), 1)
+ self.assertEquals(results[0]['intarr'], [1,2,3])
+ arrtable.delete().execute()
+
if __name__ == "__main__":
testbase.main()