From ba1352ae4426f9ea29dd185bea4d1ce7b4ba3fba Mon Sep 17 00:00:00 2001 From: Ants Aasma Date: Thu, 12 Jul 2007 00:31:32 +0000 Subject: [PATCH] Added PGArray datatype for using postgres array datatypes --- CHANGES | 3 +- lib/sqlalchemy/databases/postgres.py | 38 ++++++++++++++++++++++++- test/dialect/postgres.py | 42 ++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index c658896d5e..ce93a6d6d9 100644 --- a/CHANGES +++ b/CHANGES @@ -107,7 +107,8 @@ SelectResultsExt still exist but just return a slightly modified Query object for backwards-compatibility. join_to() method from SelectResults isn't present anymore, need to use join(). - +- postgres + - Added PGArray datatype for using postgres array datatypes 0.3.9 - general - better error message for NoSuchColumnError [ticket:607] diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index e9798967e5..4fd00198b5 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -83,6 +83,39 @@ class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" +class PGArray(sqltypes.TypeEngine): + def __init__(self, item_type): + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + + def dialect_impl(self, dialect): + impl = self.__class__.__new__(self.__class__) + impl.__dict__.update(self.__dict__) + impl.item_type = self.item_type.dialect_impl(dialect) + return impl + def convert_bind_param(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, (list,tuple)): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_bind_param(item, dialect) + return [convert_item(item) for item in value] + def convert_result_value(self, value, dialect): + if value is None: + return value + def convert_item(item): + if isinstance(item, list): + return [convert_item(child) for child in item] + else: + return self.item_type.convert_result_value(item, dialect) + # Could specialcase when item_type.convert_result_value is the default identity func + return [convert_item(item) for item in value] + def get_col_spec(self): + return self.item_type.get_col_spec() + '[]' + colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, @@ -306,8 +339,9 @@ class PGDialect(ansisql.ANSIDialect): for name, format_type, default, notnull, attnum, table_oid in rows: ## strip (30) from character varying(30) - attype = re.search('([^\(]+)', format_type).group(1) + attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull + is_array = format_type.endswith('[]') try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) @@ -360,6 +394,8 @@ class PGDialect(ansisql.ANSIDialect): if coltype: coltype = coltype(*args, **kwargs) + if is_array: + coltype = PGArray(coltype) else: warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (attype, name))) coltype = sqltypes.NULLTYPE diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 686417d7a9..7606b35684 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -190,6 +190,48 @@ class TimezoneTest(AssertMixin): x = c.last_updated_params() print x['date'] == somedate +class ArrayTest(AssertMixin): + @testbase.supported('postgres') + def setUpAll(self): + global metadata, arrtable + metadata = MetaData(testbase.db) + + arrtable = Table('arrtable', metadata, + Column('id', Integer, primary_key=True), + Column('intarr', postgres.PGArray(Integer)), + Column('strarr', postgres.PGArray(String), nullable=False) + ) + metadata.create_all() + @testbase.supported('postgres') + def tearDownAll(self): + metadata.drop_all() + @testbase.supported('postgres') + def test_reflect_array_column(self): + metadata2 = MetaData(testbase.db) + tbl = Table('arrtable', metadata2, autoload=True) + self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray)) + self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray)) + self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer)) + self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String)) + + @testbase.supported('postgres') + def test_insert_array(self): + arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) + results = arrtable.select().execute().fetchall() + self.assertEquals(len(results), 1) + self.assertEquals(results[0]['intarr'], [1,2,3]) + self.assertEquals(results[0]['strarr'], ['abc','def']) + arrtable.delete().execute() + + @testbase.supported('postgres') + def test_array_where(self): + arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) + arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') + results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall() + self.assertEquals(len(results), 1) + self.assertEquals(results[0]['intarr'], [1,2,3]) + arrtable.delete().execute() + if __name__ == "__main__": testbase.main() -- 2.47.3