author    Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 6 Nov 2008 23:07:47 +0000 (23:07 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 6 Nov 2008 23:07:47 +0000 (23:07 +0000)

- Fixed bug in Query involving order_by() in conjunction with
multiple aliases of the same class (will add tests in
[ticket:1218])
- Added a new extension sqlalchemy.ext.serializer.  Provides
Serializer/Deserializer "classes" which mirror Pickle/Unpickle,
as well as dumps() and loads().  This serializer implements
an "external object" pickler which keeps key context-sensitive
objects, including engines, sessions, metadata, Tables/Columns,
and mappers, outside of the pickle stream, and can later
restore the pickle using any engine/metadata/session provider.
This is used not for pickling regular object instances, which are
pickleable without any special logic, but for pickling expression
objects and full Query objects, such that all mapper/engine/session
dependencies can be restored at unpickle time.
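The extension rests on the pickle protocol's persistent_id/persistent_load hooks. A minimal sketch of that underlying mechanism, using only the standard library (the Heavy class and its token format are illustrative stand-ins, not part of this commit):

    import pickle
    from cStringIO import StringIO

    _registry = {}

    class Heavy(object):
        # hypothetical stand-in for a context-bound object such as an
        # Engine or Session that should stay out of the pickle stream
        def __init__(self, name):
            self.name = name
            _registry[name] = self

    def persistent_id(obj):
        # consulted for every object; returning a string token keeps the
        # object out of the stream, returning None pickles it normally
        if isinstance(obj, Heavy):
            return "heavy:" + obj.name
        return None

    def persistent_load(id_):
        # consulted for every token on the way out; resolves against
        # whatever context is current at load time
        type_, name = id_.split(":", 1)
        if type_ == "heavy":
            return _registry[name]
        raise pickle.UnpicklingError("unknown token: %r" % id_)

    buf = StringIO()
    pickler = pickle.Pickler(buf)
    pickler.persistent_id = persistent_id
    pickler.dump({"engine": Heavy("default"), "data": [1, 2, 3]})

    unpickler = pickle.Unpickler(StringIO(buf.getvalue()))
    unpickler.persistent_load = persistent_load
    assert unpickler.load()["engine"] is _registry["default"]

This is the same shape as the new extension's Serializer/Deserializer below, with engines, sessions, metadata, Tables/Columns and mappers each getting their own token scheme.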

CHANGES
doc/build/content/dbengine.txt
doc/build/gen_docstrings.py
lib/sqlalchemy/ext/serializer.py [new file with mode: 0644]
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/ext/alltests.py
test/ext/serializer.py [new file with mode: 0644]
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 3bf18cfebae46399c4853bac5e55e890040aec6b..c0ac0625f6e8b380b0e054b7c497b41dead92d15 100644
--- a/CHANGES
+++ b/CHANGES
@@ -42,6 +42,10 @@ CHANGES
       the "please specify primaryjoin" message when determining join
       condition.
 
+    - Fixed bug in Query involving order_by() in conjunction with 
+      multiple aliases of the same class (will add tests in 
+      [ticket:1218])
+      
     - When using Query.join() with an explicit clause for the ON
       clause, the clause will be aliased in terms of the left side
       of the join, allowing scenarios like query(Source).
@@ -119,6 +123,19 @@ CHANGES
     - No longer expects include_columns in table reflection to be
       lower case.
 
+- ext
+    - Added a new extension sqlalchemy.ext.serializer.  Provides
+      Serializer/Deserializer "classes" which mirror Pickle/Unpickle,
+      as well as dumps() and loads().  This serializer implements
+      an "external object" pickler which keeps key context-sensitive
+      objects, including engines, sessions, metadata, Tables/Columns,
+      and mappers, outside of the pickle stream, and can later 
+      restore the pickle using any engine/metadata/session provider.   
+      This is used not for pickling regular object instances, which are 
+      pickleable without any special logic, but for pickling expression
+      objects and full Query objects, such that all mapper/engine/session
+      dependencies can be restored at unpickle time.
+      
 - misc
     - util.flatten_iterator() func doesn't interpret strings with
       __iter__() methods as iterators, such as in pypy [ticket:1077].
diff --git a/doc/build/content/dbengine.txt b/doc/build/content/dbengine.txt
index c790685307a659f1d33bf82d056a0fa7d45b350b..9544d0689398adb96526a223dc9fc8263e6fb521 100644
--- a/doc/build/content/dbengine.txt
+++ b/doc/build/content/dbengine.txt
@@ -139,7 +139,7 @@ A list of all standard options, as well as several that are used by particular d
 * **echo=False** - if True, the Engine will log all statements as well as a repr() of their parameter lists to the engines logger, which defaults to sys.stdout.  The `echo` attribute of `Engine` can be modified at any time to turn logging on and off.  If set to the string `"debug"`, result rows will be printed to the standard output as well.  This flag ultimately controls a Python logger; see [dbengine_logging](rel:dbengine_logging) at the end of this chapter for information on how to configure logging directly.
 * **echo_pool=False** - if True, the connection pool will log all checkouts/checkins to the logging stream, which defaults to sys.stdout.  This flag ultimately controls a Python logger; see [dbengine_logging](rel:dbengine_logging) for information on how to configure logging directly.
 * **encoding='utf-8'** - the encoding to use for all Unicode translations, both by engine-wide unicode conversion as well as the `Unicode` type object.
-* **label_length=None** - optional integer value which limits the size of dynamically generated column labels to that many characters.  If less than 6, labels are generated as "_<counter>".  If `None`, the value of `dialect.max_identifier_length` is used instead.
+* **label_length=None** - optional integer value which limits the size of dynamically generated column labels to that many characters.  If less than 6, labels are generated as "_(counter)".  If `None`, the value of `dialect.max_identifier_length` is used instead.
 * **module=None** - used by database implementations which support multiple DBAPI modules, this is a reference to a DBAPI2 module to be used instead of the engine's default module.  For Postgres, the default is psycopg2.  For Oracle, it's cx_Oracle.
 * **pool=None** - an already-constructed instance of `sqlalchemy.pool.Pool`, such as a `QueuePool` instance.  If non-None, this pool will be used directly as the underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument.  For information on constructing connection pools manually, see [pooling](rel:pooling).
 * **poolclass=None** - a `sqlalchemy.pool.Pool` subclass, which will be used to create a connection pool instance using the connection parameters given in the URL.  Note this differs from `pool` in that you don't actually instantiate the pool in this case, you just indicate what type of pool to be used.
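A sketch pulling several of the options above together (the URL and the option values are placeholders):

    from sqlalchemy import create_engine
    from sqlalchemy.pool import QueuePool

    engine = create_engine(
        'postgres://scott:tiger@localhost/test',  # placeholder URL
        echo='debug',         # log statements plus result rows
        echo_pool=True,       # log pool checkouts/checkins
        encoding='utf-8',     # encoding used for Unicode translations
        label_length=30,      # cap dynamically generated column labels
        poolclass=QueuePool,  # a Pool subclass; the engine instantiates it
    )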
diff --git a/doc/build/gen_docstrings.py b/doc/build/gen_docstrings.py
index 457f6a9deda253bc71e40414db9b7ce28c8f2c78..ee9bcf0ba3979db0cbadc973b084ca95b9cffe35 100644
--- a/doc/build/gen_docstrings.py
+++ b/doc/build/gen_docstrings.py
@@ -5,11 +5,8 @@ import re
 from sqlalchemy import schema, types, engine, sql, pool, orm, exceptions, databases, interfaces
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, strategies, threadlocal, url
-import sqlalchemy.orm.shard
-import sqlalchemy.ext.orderinglist as orderinglist
-import sqlalchemy.ext.associationproxy as associationproxy
-import sqlalchemy.ext.sqlsoup as sqlsoup
-import sqlalchemy.ext.declarative as declarative
+from sqlalchemy.orm import shard
+from sqlalchemy.ext import orderinglist, associationproxy, sqlsoup, declarative, serializer
 
 def make_doc(obj, classes=None, functions=None, **kwargs):
     """generate a docstring.ObjectDoc structure for an individual module, list of classes, and list of functions."""
@@ -47,6 +44,7 @@ def make_all_docs():
         make_doc(obj=declarative),
         make_doc(obj=associationproxy, classes=[associationproxy.AssociationProxy]),
         make_doc(obj=orderinglist, classes=[orderinglist.OrderingList]),
+        make_doc(obj=serializer),
         make_doc(obj=sqlsoup),
     ] + [make_doc(getattr(__import__('sqlalchemy.databases.%s' % m).databases, m)) for m in databases.__all__]
     return objects
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
new file mode 100644
index 0000000..b62ee0c
--- /dev/null
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -0,0 +1,129 @@
+"""Serializer/Deserializer objects for usage with SQLAlchemy structures.
+
+Any SQLAlchemy structure, including Tables, Columns, expressions, mappers,
+Query objects etc. can be serialized in a minimally-sized format,
+and deserialized when given a Metadata and optional ScopedSession object
+to use as context on the way out.
+
+Usage is nearly the same as that of the standard Python pickle module::
+
+    from sqlalchemy.ext.serializer import loads, dumps
+    metadata = MetaData(bind=some_engine)
+    Session = scoped_session(sessionmaker())
+    
+    # ... define mappers
+    
+    query = Session.query(MyClass).filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
+    
+    # pickle the query
+    serialized = dumps(query)
+    
+    # unpickle.  Pass in metadata + scoped_session
+    query2 = loads(serialized, metadata, Session)
+    
+    print query2.all()
+
+Similar restrictions as when using raw pickle apply; mapped classes must
+themselves be pickleable, meaning they are importable from a module-level
+namespace.
+
+Note that instances of user-defined classes do not require this extension
+in order to be pickled; these contain no references to engines, sessions
+or expression constructs in the typical case and can be serialized directly.
+This module is specifically for ORM and expression constructs.
+
+"""
+
+from sqlalchemy.orm import class_mapper, Query
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm.mapper import Mapper
+from sqlalchemy.orm.attributes import QueryableAttribute
+from sqlalchemy import Table, Column
+from sqlalchemy.engine import Engine
+from sqlalchemy.util import pickle
+import re
+import base64
+from cStringIO import StringIO
+
+__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
+
+def Serializer(*args, **kw):
+    pickler = pickle.Pickler(*args, **kw)
+        
+    def persistent_id(obj):
+        #print "serializing:", repr(obj)
+        if isinstance(obj, QueryableAttribute):
+            cls = obj.impl.class_
+            key = obj.impl.key
+            id = "attribute:" + key + ":" + base64.b64encode(pickle.dumps(cls))
+        elif isinstance(obj, Mapper) and not obj.non_primary:
+            id = "mapper:" + base64.b64encode(pickle.dumps(obj.class_))
+        elif isinstance(obj, Table):
+            id = "table:" + str(obj)
+        elif isinstance(obj, Column) and isinstance(obj.table, Table):
+            id = "column:" + str(obj.table) + ":" + obj.key
+        elif isinstance(obj, Session):
+            id = "session:"
+        elif isinstance(obj, Engine):
+            id = "engine:"
+        else:
+            return None
+        return id
+        
+    pickler.persistent_id = persistent_id
+    return pickler
+    
+our_ids = re.compile(r'(mapper|table|column|session|attribute|engine):(.*)')
+
+def Deserializer(file, metadata=None, scoped_session=None, engine=None):
+    unpickler = pickle.Unpickler(file)
+    
+    def get_engine():
+        if engine:
+            return engine
+        elif scoped_session and scoped_session().bind:
+            return scoped_session().bind
+        elif metadata and metadata.bind:
+            return metadata.bind
+        else:
+            return None
+            
+    def persistent_load(id):
+        m = our_ids.match(id)
+        if not m:
+            return None
+        else:
+            type_, args = m.group(1, 2)
+            if type_ == 'attribute':
+                key, clsarg = args.split(":")
+                cls = pickle.loads(base64.b64decode(clsarg))
+                return getattr(cls, key)
+            elif type_ == "mapper":
+                cls = pickle.loads(base64.b64decode(args))
+                return class_mapper(cls)
+            elif type_ == "table":
+                return metadata.tables[args]
+            elif type_ == "column":
+                table, colname = args.split(':')
+                return metadata.tables[table].c[colname]
+            elif type_ == "session":
+                return scoped_session()
+            elif type_ == "engine":
+                return get_engine()
+            else:
+                raise Exception("Unknown token: %s" % type_)
+    unpickler.persistent_load = persistent_load
+    return unpickler
+
+def dumps(obj):
+    buf = StringIO()
+    pickler = Serializer(buf)
+    pickler.dump(obj)
+    return buf.getvalue()
+    
+def loads(data, metadata=None, scoped_session=None, engine=None):
+    buf = StringIO(data)
+    unpickler = Deserializer(buf, metadata, scoped_session, engine)
+    return unpickler.load()
+    
+    
\ No newline at end of file
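dumps() and loads() are convenience wrappers; the Serializer/Deserializer constructors can also be driven directly against any file-like object. A sketch under that reading of the code above (the widgets table is hypothetical):

    from cStringIO import StringIO
    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.ext.serializer import Serializer, Deserializer

    metadata = MetaData()
    widgets = Table('widgets', metadata,  # hypothetical table
        Column('id', Integer, primary_key=True))

    buf = StringIO()
    pickler = Serializer(buf)        # a pickle.Pickler with persistent_id set
    pickler.dump(widgets.c.id == 5)  # expression holding Column/Table refs

    # only a MetaData is needed as context for tables and columns; the
    # Column comes back as the very same object, not a copy
    unpickler = Deserializer(StringIO(buf.getvalue()), metadata=metadata)
    expr = unpickler.load()
    assert expr.left is widgets.c.id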
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index aa30f151744bb9ce4043071b72757cf37322199d..51165287fb181b62519389b79eea336f51fcdd29 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -195,12 +195,12 @@ class Query(object):
         if as_filter and self._filter_aliases:
             adapters.append(self._filter_aliases.replace)
 
-        if self._polymorphic_adapters:
-            adapters.append(self.__adapt_polymorphic_element)
-
         if self._from_obj_alias:
             adapters.append(self._from_obj_alias.replace)
 
+        if self._polymorphic_adapters:
+            adapters.append(self.__adapt_polymorphic_element)
+
         if not adapters:
             return clause
 
@@ -1707,9 +1707,9 @@ class _MapperEntity(_QueryEntity):
         if context.order_by is False and self.mapper.order_by:
             context.order_by = self.mapper.order_by
 
-        if context.order_by and adapter:
-            context.order_by = adapter.adapt_list(util.to_list(context.order_by))
-
+            if adapter:
+                context.order_by = adapter.adapt_list(util.to_list(context.order_by))
+                    
         for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic):
             if query._only_load_props and value.key not in query._only_load_props:
                 continue
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 689df8d867cfd592f2c307c3c0812ec32255b012..405acda15f93509e0bca0f56ab039ebe0ea0fd35 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -281,6 +281,19 @@ class AliasedClass(object):
         self._sa_label_name = name
         self.__name__ = 'AliasedClass_' + str(self.__target)
 
+    def __getstate__(self):
+        return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}
+    
+    def __setstate__(self, state):
+        self.__mapper = state['mapper']
+        self.__target = self.__mapper.class_
+        alias = state['alias']
+        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
+        self.__alias = alias
+        name = state['name']
+        self._sa_label_name = name
+        self.__name__ = 'AliasedClass_' + str(self.__target)
+        
     def __adapt_element(self, elem):
         return self.__adapter.traverse(elem)._annotate({'parententity': self})
         
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 85f229ba0da380c827a3994987f913f7f2e07334..5206dc5fa8c83b6d800a3dde7cbced5696615054 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1002,6 +1002,11 @@ class ClauseElement(Visitable):
             yield f
             f = getattr(f, '_is_clone_of', None)
 
+    def __getstate__(self):
+        d = self.__dict__.copy()
+        d.pop('_is_clone_of', None)
+        return d
+        
     def _get_from_objects(self, **modifiers):
         """Return objects represented in this ``ClauseElement`` that
         should be added to the ``FROM`` list of a query, when this
@@ -1959,7 +1964,17 @@ class _BindParamClause(ColumnElement):
 
         """
         return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
-
+    
+    def __getstate__(self):
+        """execute a deferred value for serialization purposes."""
+        
+        d = self.__dict__.copy()
+        v = self.value
+        if callable(v):
+            v = v()
+        d['value'] = v
+        return d
+        
     def __repr__(self):
         return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
 
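The __getstate__ added above means a callable bind value is resolved once at pickle time instead of being pickled itself. A small sketch of the intended effect (assuming plain pickle and a callable passed as the bindparam value):

    import pickle
    from sqlalchemy import bindparam

    p = bindparam('current', lambda: 5)  # value supplied lazily as a callable
    restored = pickle.loads(pickle.dumps(p))
    assert restored.value == 5           # __getstate__ executed the callable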
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 2a510906b1275947a57481e5ce8c76c83ddf54af..d5f2417c27bce1b4756b94766e14db0b3dbf59cf 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -121,6 +121,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
     else:
         return sql.and_(*crit)
 
+    
 class Annotated(object):
     """clones a ClauseElement and applies an 'annotations' dictionary.
     
@@ -133,14 +134,17 @@ class Annotated(object):
     hash value may be reused, causing conflicts.
 
     """
+    
     def __new__(cls, *args):
         if not args:
+            # clone constructor
             return object.__new__(cls)
         else:
             element, values = args
-            return object.__new__(
-                type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) 
-            )
+            # pull appropriate subclass from this module's 
+            # namespace (see below for rationale)
+            cls = eval("Annotated%s"  % element.__class__.__name__)
+            return object.__new__(cls)
 
     def __init__(self, element, values):
         # force FromClause to generate their internal 
@@ -180,6 +184,17 @@ class Annotated(object):
     def __cmp__(self, other):
         return cmp(hash(self.__element), hash(other))
 
+# hard-generate Annotated subclasses.  this technique
+# is used instead of on-the-fly types (i.e. type.__new__())
+# so that the resulting objects are pickleable.
+from sqlalchemy.sql import expression
+for cls in expression.__dict__.values() + [schema.Column, schema.Table]:
+    if isinstance(cls, type) and issubclass(cls, expression.ClauseElement):
+        exec "class Annotated%s(Annotated, cls):\n" \
+             "    __visit_name__ = cls.__visit_name__\n"\
+             "    pass" % (cls.__name__, ) in locals()
+
+
 def _deep_annotate(element, annotations, exclude=None):
     """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary.
 
@@ -495,3 +510,11 @@ class ColumnAdapter(ClauseAdapter):
     def adapted_row(self, row):
         return AliasedRow(row, self.columns)
     
+    def __getstate__(self):
+        d = self.__dict__.copy()
+        del d['columns']
+        return d
+        
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.columns = util.PopulateDict(self._locate_col)
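The module-level exec loop above exists because pickle locates a class by importing it from its defining module; a type built on the fly has no importable home. A self-contained illustration of the difference (class names are illustrative):

    import pickle

    class Base(object):
        pass

    # built on the fly, as the previous Annotated.__new__ did
    OnTheFly = type('AnnotatedLabel', (Base,), {})

    try:
        pickle.dumps(OnTheFly())
    except pickle.PicklingError, e:
        print "dynamic class fails:", e  # not found as __main__.AnnotatedLabel

    # bound to a module-level name, as the exec loop arranges
    class AnnotatedLabel(Base):
        pass

    assert isinstance(pickle.loads(pickle.dumps(AnnotatedLabel())),
                      AnnotatedLabel)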
diff --git a/test/ext/alltests.py b/test/ext/alltests.py
index ff645b335bd4f693b7444bfd8c77047c4e96fa3b..d67f6bf428f4c65c50a4c916214b5c38c46905db 100644
--- a/test/ext/alltests.py
+++ b/test/ext/alltests.py
@@ -7,6 +7,7 @@ def suite():
         'ext.declarative',
         'ext.orderinglist',
         'ext.associationproxy',
+        'ext.serializer',
         )
 
     if sys.version_info < (2, 4):
diff --git a/test/ext/serializer.py b/test/ext/serializer.py
new file mode 100644
index 0000000..0a900a9
--- /dev/null
+++ b/test/ext/serializer.py
@@ -0,0 +1,127 @@
+import testenv; testenv.configure_for_tests()
+
+from sqlalchemy.ext import serializer
+from sqlalchemy import exc
+from testlib import sa, testing
+from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func
+from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased
+from testlib.testing import eq_
+from orm._base import ComparableEntity, MappedTest
+
+
+class User(ComparableEntity):
+    pass
+
+class Address(ComparableEntity):
+    pass
+
+class SerializeTest(testing.ORMTest):
+    keep_mappers = True
+    keep_data = True
+    
+    def define_tables(self, metadata):
+        global users, addresses
+        users = Table('users', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50))
+        )
+        addresses = Table('addresses', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('email', String(50)),
+            Column('user_id', Integer, ForeignKey('users.id')),
+        )
+
+    def setup_mappers(self):
+        global Session
+        Session = scoped_session(sessionmaker())
+
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user', order_by=addresses.c.id)
+        })
+        mapper(Address, addresses)
+
+        compile_mappers()
+        
+    def insert_data(self):
+        params = [dict(zip(('id', 'name'), column_values)) for column_values in 
+            [(7, 'jack'),
+            (8, 'ed'),
+            (9, 'fred'),
+            (10, 'chuck')]
+        ]
+        users.insert().execute(params)
+    
+        addresses.insert().execute(
+            [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in 
+                [(1, 7, "jack@bean.com"),
+                (2, 8, "ed@wood.com"),
+                (3, 8, "ed@bettyboop.com"),
+                (4, 8, "ed@lala.com"),
+                (5, 9, "fred@fred.com")]
+            ]
+        )
+    
+    def test_tables(self):
+        assert serializer.loads(serializer.dumps(users), users.metadata, Session) is users
+
+    def test_columns(self):
+        assert serializer.loads(serializer.dumps(users.c.name), users.metadata, Session) is users.c.name
+        
+    def test_mapper(self):
+        user_mapper = class_mapper(User)
+        assert serializer.loads(serializer.dumps(user_mapper), None, None) is user_mapper
+    
+    def test_attribute(self):
+        assert serializer.loads(serializer.dumps(User.name), None, None) is User.name
+    
+    def test_expression(self):
+        
+        expr = select([users]).select_from(users.join(addresses)).limit(5)
+        re_expr = serializer.loads(serializer.dumps(expr), users.metadata, None)
+        eq_(
+            str(expr), 
+            str(re_expr)
+        )
+        
+        assert re_expr.bind is testing.db
+        eq_(
+            re_expr.execute().fetchall(),
+            [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')]
+        )
+        
+    def test_query(self):
+        q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses))
+        eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])])
+        
+        q2 = serializer.loads(serializer.dumps(q), users.metadata, Session)
+        def go():
+            eq_(q2.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])])
+        self.assert_sql_count(testing.db, go, 1)
+        
+        eq_(q2.join(User.addresses).filter(Address.email=='ed@bettyboop.com').value(func.count('*')), 1)
+
+        u1 = Session.query(User).get(8)
+        
+        q = Session.query(Address).filter(Address.user==u1).order_by(desc(Address.email))
+        q2 = serializer.loads(serializer.dumps(q), users.metadata, Session)
+        
+        eq_(q2.all(), [Address(email='ed@wood.com'), Address(email='ed@lala.com'), Address(email='ed@bettyboop.com')])
+        
+        q = Session.query(User).join(User.addresses).filter(Address.email.like('%fred%'))
+        q2 = serializer.loads(serializer.dumps(q), users.metadata, Session)
+        eq_(q2.all(), [User(name='fred')])
+        
+        eq_(list(q2.values(User.id, User.name)), [(9, u'fred')])
+
+    def test_aliases(self):
+        u7, u8, u9, u10 = Session.query(User).order_by(User.id).all()
+
+        ualias = aliased(User)
+        q = Session.query(User, ualias).join((ualias, User.id < ualias.id)).filter(User.id<9).order_by(User.id, ualias.id)
+
+        q2 = serializer.loads(serializer.dumps(q), users.metadata, Session)
+        
+        eq_(list(q2.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)])
+        
+if __name__ == '__main__':
+    testing.main()
diff --git a/test/orm/query.py b/test/orm/query.py
index e50cacc86eff3ac9ee811c86b469853422b401ad..b63c12f09216295c91d42c45edbb0f0c49afbf19 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -1028,7 +1028,7 @@ class JoinTest(QueryTest):
         q = sess.query(User)
         AdAlias = aliased(Address)
         q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias))
-        l = q.order_by(User.id, Address.id).all()
+        l = q.order_by(User.id, AdAlias.id).all()
         self.assertEquals(l, expected)
 
         sess.clear()
@@ -1705,7 +1705,7 @@ class MixedEntitiesTest(QueryTest):
         q = sess.query(User)
         adalias = addresses.alias('adalias')
         q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias))
-        l = q.order_by(User.id, Address.id).all()
+        l = q.order_by(User.id, adalias.c.id).all()
         assert l == expected
 
         sess.clear()