From: Mike Bayer Date: Sun, 11 Oct 2009 17:16:53 +0000 (+0000) Subject: - RowProxy objects are now pickleable, i.e. the object returned X-Git-Tag: rel_0_6beta1~260 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=114ad36894ab37280106feb15e5421ac124c6834;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - RowProxy objects are now pickleable, i.e. the object returned by result.fetchone(), result.fetchall() etc. - the "named tuple" objects returned when iterating a Query() are now pickleable. --- diff --git a/CHANGES b/CHANGES index add2ece7c0..44bd7e577a 100644 --- a/CHANGES +++ b/CHANGES @@ -22,6 +22,15 @@ CHANGES - query.update() and query.delete() both default to 'evaluate' for the synchronize strategy. + + - the 'synchronize' strategy for update() and delete() raises + an error on failure. There is no implicit fallback onto + "fetch". Failure of evaluation is based on the structure of + criteria, so success/failure is deterministic based on code + structure. + + - the "named tuple" objects returned when iterating a + Query() are now pickleable. - query.join() has been reworked to provide more consistent behavior and more flexibility (includes [ticket:1537]) @@ -30,12 +39,6 @@ CHANGES multiple comma separated entries within the FROM clause. Useful when selecting from multiple-homed join() clauses. - - the 'synchronize' strategy for update() and delete() raises - an error on failure. There is no implicit fallback onto - "fetch". Failure of evaluation is based on the structure of - criteria, so success/failure is deterministic based on code - structure. - - the "dont_load=True" flag on Session.merge() is deprecated and is now "load=False". @@ -121,6 +124,9 @@ CHANGES - added first() method to ResultProxy, returns first row and closes result set immediately. + - RowProxy objects are now pickleable, i.e. the object returned + by result.fetchone(), result.fetchall() etc. + - schema - deprecated MetaData.connect() and ThreadLocalMetaData.connect() have been removed - send diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 70ee295db4..26e44dd6b0 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1497,7 +1497,7 @@ class RowProxy(object): self.__row = row if self.__parent._echo: self.__parent.context.engine.logger.debug("Row %r", row) - + def close(self): """Close the parent ResultProxy.""" @@ -1508,7 +1508,17 @@ class RowProxy(object): def __len__(self): return len(self.__row) - + + def __getstate__(self): + return { + '__row':[self.__parent._get_col(self.__row, i) for i in xrange(len(self.__row))], + '__parent':PickledResultProxy(self.__parent) + } + + def __setstate__(self, d): + self.__row = d['__row'] + self.__parent = d['__parent'] + def __iter__(self): for i in xrange(len(self.__row)): yield self.__parent._get_col(self.__row, i) @@ -1561,7 +1571,52 @@ class RowProxy(object): def itervalues(self): return iter(self) - +class PickledResultProxy(object): + """a 'mock' ResultProxy used by a RowProxy being pickled.""" + + _echo = False + + def __init__(self, resultproxy): + self._props = dict( + (k, resultproxy._props[k][2]) for k in resultproxy._props + if isinstance(k, (basestring, int)) + ) + self._keys = resultproxy.keys + + def _fallback_key(self, key): + if key in self._props: + return self._props[key] + + if isinstance(key, basestring): + key = key.lower() + if key in self._props: + return self._props[key] + + if isinstance(key, expression.ColumnElement): + if key._label and key._label.lower() in self._props: + return self._props[key._label.lower()] + elif hasattr(key, 'name') and key.name.lower() in self._props: + return self._props[key.name.lower()] + + return None + + def close(self): + pass + + def _has_key(self, row, key): + return self._fallback_key(key) is not None + + def _get_col(self, row, orig_key): + key = self._fallback_key(orig_key) + if key is None: + raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % orig_key) + return row[key] + + @property + def keys(self): + return self._keys + + class BufferedColumnRow(RowProxy): def __init__(self, parent, row): row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] @@ -1639,7 +1694,7 @@ class ResultProxy(object): """ return self.cursor.lastrowid - + def _cursor_description(self): return self.cursor.description @@ -1732,7 +1787,7 @@ class ResultProxy(object): elif hasattr(key, 'name') and key.name.lower() in props: return props[key.name.lower()] - raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) + raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % key) return fallback def __ambiguous_processor(self, colname): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index b1b85cd01c..b347e205e8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1371,11 +1371,7 @@ class Query(object): (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities]) if not single_entity: - labels = dict((label, property(itemgetter(i))) - for i, label in enumerate(labels) - if label) - rowtuple = type.__new__(type, "RowTuple", (tuple,), labels) - rowtuple.keys = labels.keys + labels = [l for l in labels if l] while True: context.progress = {} @@ -1395,7 +1391,7 @@ class Query(object): elif single_entity: rows = [process[0](context, row) for row in fetch] else: - rows = [rowtuple(proc(context, row) for proc in process) + rows = [util.NamedTuple(labels, (proc(context, row) for proc in process)) for row in fetch] if filter: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 67990a2028..8f0b5583dd 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -635,6 +635,23 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None, pass setattr(into_cls, method, env[method]) +class NamedTuple(tuple): + """tuple() subclass that adds labeled names. + + Is also pickleable. + + """ + + def __new__(cls, labels, vals): + vals = list(vals) + t = tuple.__new__(cls, vals) + t.__dict__ = dict(zip(labels, vals)) + t._labels = labels + return t + + def keys(self): + return self._labels + class OrderedProperties(object): """An object that maintains the order in which attributes are set upon it. diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 31547d16db..0ec2b998d9 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -2315,32 +2315,44 @@ class MixedEntitiesTest(QueryTest): def test_tuple_labeling(self): sess = create_session() - for row in sess.query(User, Address).join(User.addresses).all(): - eq_(set(row.keys()), set(['User', 'Address'])) - eq_(row.User, row[0]) - eq_(row.Address, row[1]) - for row in sess.query(User.name, User.id.label('foobar')): - eq_(set(row.keys()), set(['name', 'foobar'])) - eq_(row.name, row[0]) - eq_(row.foobar, row[1]) - - for row in sess.query(User).values(User.name, User.id.label('foobar')): - eq_(set(row.keys()), set(['name', 'foobar'])) - eq_(row.name, row[0]) - eq_(row.foobar, row[1]) - - oalias = aliased(Order) - for row in sess.query(User, oalias).join(User.orders).all(): - eq_(set(row.keys()), set(['User'])) - eq_(row.User, row[0]) - - oalias = aliased(Order, name='orders') - for row in sess.query(User, oalias).join(User.orders).all(): - eq_(set(row.keys()), set(['User', 'orders'])) - eq_(row.User, row[0]) - eq_(row.orders, row[1]) + for pickled in False, True: + for row in sess.query(User, Address).join(User.addresses).all(): + if pickled: + row = util.pickle.loads(util.pickle.dumps(row)) + + eq_(set(row.keys()), set(['User', 'Address'])) + eq_(row.User, row[0]) + eq_(row.Address, row[1]) + + for row in sess.query(User.name, User.id.label('foobar')): + if pickled: + row = util.pickle.loads(util.pickle.dumps(row)) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) + + for row in sess.query(User).values(User.name, User.id.label('foobar')): + if pickled: + row = util.pickle.loads(util.pickle.dumps(row)) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) + oalias = aliased(Order) + for row in sess.query(User, oalias).join(User.orders).all(): + if pickled: + row = util.pickle.loads(util.pickle.dumps(row)) + eq_(set(row.keys()), set(['User'])) + eq_(row.User, row[0]) + + oalias = aliased(Order, name='orders') + for row in sess.query(User, oalias).join(User.orders).all(): + if pickled: + row = util.pickle.loads(util.pickle.dumps(row)) + eq_(set(row.keys()), set(['User', 'orders'])) + eq_(row.User, row[0]) + eq_(row.orders, row[1]) def test_column_queries(self): sess = create_session() diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 3222ff6ef4..470a694fb9 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1,10 +1,10 @@ from sqlalchemy.test.testing import eq_ import datetime from sqlalchemy import * -from sqlalchemy import exc, sql +from sqlalchemy import exc, sql, util from sqlalchemy.engine import default from sqlalchemy.test import * -from sqlalchemy.test.testing import eq_, assert_raises_message +from sqlalchemy.test.testing import eq_, assert_raises_message, assert_raises from sqlalchemy.test.schema import Table, Column class QueryTest(TestBase): @@ -207,7 +207,7 @@ class QueryTest(TestBase): for row in select([sel + 1, sel + 3], bind=users.bind).execute(): assert row['anon_1'] == 8 assert row['anon_2'] == 10 - + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") def test_order_by_label(self): """test that a label within an ORDER BY works on each backend. @@ -260,6 +260,47 @@ class QueryTest(TestBase): self.assert_(not (rp != equal)) self.assert_(not (equal != equal)) + def test_pickled_rows(self): + users.insert().execute( + {'user_id':7, 'user_name':'jack'}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9, 'user_name':'fred'}, + ) + + for pickle in False, True: + for use_labels in False, True: + result = users.select(use_labels=use_labels).order_by(users.c.user_id).execute().fetchall() + + if pickle: + result = util.pickle.loads(util.pickle.dumps(result)) + + eq_( + result, + [(7, "jack"), (8, "ed"), (9, "fred")] + ) + if use_labels: + eq_(result[0]['query_users_user_id'], 7) + eq_(result[0].keys(), ["query_users_user_id", "query_users_user_name"]) + else: + eq_(result[0]['user_id'], 7) + eq_(result[0].keys(), ["user_id", "user_name"]) + + eq_(result[0][0], 7) + eq_(result[0][users.c.user_id], 7) + eq_(result[0][users.c.user_name], 'jack') + + if use_labels: + assert_raises(exc.NoSuchColumnError, lambda: result[0][addresses.c.user_id]) + else: + # test with a different table. name resolution is + # causing 'user_id' to match when use_labels wasn't used. + eq_(result[0][addresses.c.user_id], 7) + + assert_raises(exc.NoSuchColumnError, lambda: result[0]['fake key']) + assert_raises(exc.NoSuchColumnError, lambda: result[0][addresses.c.address_id]) + + + @testing.requires.boolean_col_expressions def test_or_and_as_columns(self): true, false = literal(True), literal(False)