From 42ac34a7019edc79f204576237cce23c107c50ca Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 20 Oct 2010 16:17:17 -0400 Subject: [PATCH] - Added "as_tuple" flag to pg ARRAY type, returns results as tuples instead of lists to allow hashing. --- CHANGES | 4 ++ lib/sqlalchemy/dialects/postgresql/base.py | 34 ++++++++++++---- test/dialect/test_postgresql.py | 46 +++++++++++++++++++--- 3 files changed, 71 insertions(+), 13 deletions(-) diff --git a/CHANGES b/CHANGES index 63761d6234..04c75ac596 100644 --- a/CHANGES +++ b/CHANGES @@ -194,6 +194,10 @@ CHANGES boolean values for 'use_native_unicode'. [ticket:1899] +- postgresql + - Added "as_tuple" flag to ARRAY type, returns results + as tuples instead of lists to allow hashing. + - mssql - Fixed reflection bug which did not properly handle reflection of unknown types. [ticket:1946] diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 89769b8c0a..03260cdb34 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -171,7 +171,7 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): """ __visit_name__ = 'ARRAY' - def __init__(self, item_type, mutable=True): + def __init__(self, item_type, mutable=True, as_tuple=False): """Construct an ARRAY. E.g.:: @@ -186,9 +186,14 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): ``ARRAY(ARRAY(Integer))`` or such. The type mapping figures out on the fly - :param mutable: Defaults to True: specify whether lists passed to this + :param mutable=True: Specify whether lists passed to this class should be considered mutable. If so, generic copy operations (typically used by the ORM) will shallow-copy values. + + :param as_tuple=False: Specify whether return results should be converted + to tuples from lists. DBAPIs such as psycopg2 return lists by default. + When tuples are returned, the results are hashable. This flag can only + be set to ``True`` when ``mutable`` is set to ``False``. (new in 0.6.5) """ if isinstance(item_type, ARRAY): @@ -198,7 +203,12 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): item_type = item_type() self.item_type = item_type self.mutable = mutable - + if mutable and as_tuple: + raise exc.ArgumentError( + "mutable must be set to False if as_tuple is True." + ) + self.as_tuple = as_tuple + def copy_value(self, value): if value is None: return None @@ -224,7 +234,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): def adapt(self, impltype): return impltype( self.item_type, - mutable=self.mutable + mutable=self.mutable, + as_tuple=self.as_tuple ) def bind_processor(self, dialect): @@ -252,19 +263,28 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): if item_proc: def convert_item(item): if isinstance(item, list): - return [convert_item(child) for child in item] + r = [convert_item(child) for child in item] + if self.as_tuple: + r = tuple(r) + return r else: return item_proc(item) else: def convert_item(item): if isinstance(item, list): - return [convert_item(child) for child in item] + r = [convert_item(child) for child in item] + if self.as_tuple: + r = tuple(r) + return r else: return item def process(value): if value is None: return value - return [convert_item(item) for item in value] + r = [convert_item(item) for item in value] + if self.as_tuple: + r = tuple(r) + return r return process PGArray = ARRAY diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 9ad46c189c..36aa7f2c6f 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -1492,8 +1492,8 @@ class ArrayTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) arrtable = Table('arrtable', metadata, Column('id', Integer, primary_key=True), Column('intarr', - postgresql.PGArray(Integer)), Column('strarr', - postgresql.PGArray(Unicode()), nullable=False)) + postgresql.ARRAY(Integer)), Column('strarr', + postgresql.ARRAY(Unicode()), nullable=False)) metadata.create_all() def teardown(self): @@ -1506,8 +1506,8 @@ class ArrayTest(TestBase, AssertsExecutionResults): def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - assert isinstance(tbl.c.intarr.type, postgresql.PGArray) - assert isinstance(tbl.c.strarr.type, postgresql.PGArray) + assert isinstance(tbl.c.intarr.type, postgresql.ARRAY) + assert isinstance(tbl.c.strarr.type, postgresql.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) @@ -1575,7 +1575,7 @@ class ArrayTest(TestBase, AssertsExecutionResults): footable = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('intarr', - postgresql.PGArray(Integer), nullable=True)) + postgresql.ARRAY(Integer), nullable=True)) mapper(Foo, footable) metadata.create_all() sess = create_session() @@ -1607,7 +1607,41 @@ class ArrayTest(TestBase, AssertsExecutionResults): foo.id = 2 sess.add(foo) sess.flush() - + + @testing.provide_metadata + def test_tuple_flag(self): + assert_raises_message( + exc.ArgumentError, + "mutable must be set to False if as_tuple is True.", + postgresql.ARRAY, Integer, as_tuple=True) + + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', postgresql.ARRAY(String(5), as_tuple=True, mutable=False)), + Column('data2', postgresql.ARRAY(Numeric(asdecimal=False), as_tuple=True, mutable=False)), + ) + metadata.create_all() + testing.db.execute(t1.insert(), id=1, data=["1","2","3"], data2=[5.4, 5.6]) + testing.db.execute(t1.insert(), id=2, data=["4", "5", "6"], data2=[1.0]) + testing.db.execute(t1.insert(), id=3, data=[["4", "5"], ["6", "7"]], data2=[[5.4, 5.6], [1.0, 1.1]]) + + r = testing.db.execute(t1.select().order_by(t1.c.id)).fetchall() + eq_( + r, + [ + (1, ('1', '2', '3'), (5.4, 5.6)), + (2, ('4', '5', '6'), (1.0,)), + (3, (('4', '5'), ('6', '7')), ((5.4, 5.6), (1.0, 1.1))) + ] + ) + # hashable + eq_( + set(row[1] for row in r), + set([('1', '2', '3'), ('4', '5', '6'), (('4', '5'), ('6', '7'))]) + ) + + + class TimestampTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgresql' -- 2.47.2