]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- renamed query.slice_() to query.slice()
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 May 2008 20:35:41 +0000 (20:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 May 2008 20:35:41 +0000 (20:35 +0000)
- pulled out DeclarativeMeta.__init__ into its own function, added instrument_declarative()
which will do the "declarative" thing to any class independent of its lineage (for ctheune)
- added "cls" kwarg to declarative_base() allowing user-defined base class for declarative base [ticket:1042]

lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/orm/query.py
test/ext/declarative.py
test/orm/query.py

index 4778b9eba4c570f24ff478d48a265efe65e18ee4..b29f051b15e203402c7df38b8b64c34a8e624d5c 100644 (file)
@@ -188,79 +188,93 @@ from sqlalchemy import util, exceptions
 from sqlalchemy.sql import util as sql_util
 
 
-__all__ = 'declarative_base', 'synonym_for', 'comparable_using'
+__all__ = 'declarative_base', 'synonym_for', 'comparable_using', 'instrument_declarative'
 
+def instrument_declarative(cls, registry, metadata):
+    """Given a class, configure the class declaratively,
+    using the given registry (any dictionary) and MetaData object.
+    This operation does not assume any kind of class hierarchy.
+    
+    """
+    if '_decl_class_registry' in cls.__dict__:
+        raise exceptions.InvalidRequestError("Class %r already has been instrumented declaratively" % cls)
+    cls._decl_class_registry = registry
+    cls.metadata = metadata
+    _as_declarative(cls, cls.__name__, cls.__dict__)
+    
+def _as_declarative(cls, classname, dict_):
+    cls._decl_class_registry[classname] = cls
+    our_stuff = util.OrderedDict()
+    for k in dict_:
+        value = dict_[k]
+        if (isinstance(value, tuple) and len(value) == 1 and
+            isinstance(value[0], (Column, MapperProperty))):
+            util.warn("Ignoring declarative-like tuple value of attribute "
+                      "%s: possibly a copy-and-paste error with a comma "
+                      "left at the end of the line?" % k)
+            continue
+        if not isinstance(value, (Column, MapperProperty)):
+            continue
+        prop = _deferred_relation(cls, value)
+        our_stuff[k] = prop
+
+    # set up attributes in the order they were created
+    our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order,
+                                    our_stuff[y]._creation_order))
+
+    table = None
+    if '__table__' not in cls.__dict__:
+        if '__tablename__' in cls.__dict__:
+            tablename = cls.__tablename__
+            autoload = cls.__dict__.get('__autoload__')
+            if autoload:
+                table_kw = {'autoload': True}
+            else:
+                table_kw = {}
+            cols = []
+            for key, c in our_stuff.iteritems():
+                if isinstance(c, ColumnProperty):
+                    for col in c.columns:
+                        if isinstance(col, Column) and col.table is None:
+                            _undefer_column_name(key, col)
+                            cols.append(col)
+                elif isinstance(c, Column):
+                    _undefer_column_name(key, c)
+                    cols.append(c)
+            cls.__table__ = table = Table(tablename, cls.metadata,
+                                          *cols, **table_kw)
+    else:
+        table = cls.__table__
+
+    mapper_args = getattr(cls, '__mapper_args__', {})
+    if 'inherits' not in mapper_args:
+        inherits = cls.__mro__[1]
+        inherits = cls._decl_class_registry.get(inherits.__name__, None)
+        if inherits:
+            mapper_args['inherits'] = inherits
+            if not mapper_args.get('concrete', False) and table:
+                # figure out the inherit condition with relaxed rules
+                # about nonexistent tables, to allow for ForeignKeys to
+                # not-yet-defined tables (since we know for sure that our
+                # parent table is defined within the same MetaData)
+                mapper_args['inherit_condition'] = sql_util.join_condition(
+                    inherits.__table__, table,
+                    ignore_nonexistent_tables=True)
+
+    if hasattr(cls, '__mapper_cls__'):
+        mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
+    else:
+        mapper_cls = mapper
+
+    cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff,
+                                **mapper_args)
 
 class DeclarativeMeta(type):
     def __init__(cls, classname, bases, dict_):
         if '_decl_class_registry' in cls.__dict__:
             return type.__init__(cls, classname, bases, dict_)
-
-        cls._decl_class_registry[classname] = cls
-        our_stuff = util.OrderedDict()
-        for k in dict_:
-            value = dict_[k]
-            if (isinstance(value, tuple) and len(value) == 1 and
-                isinstance(value[0], (Column, MapperProperty))):
-                util.warn("Ignoring declarative-like tuple value of attribute "
-                          "%s: possibly a copy-and-paste error with a comma "
-                          "left at the end of the line?" % k)
-                continue
-            if not isinstance(value, (Column, MapperProperty)):
-                continue
-            prop = _deferred_relation(cls, value)
-            our_stuff[k] = prop
-
-        # set up attributes in the order they were created
-        our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order,
-                                        our_stuff[y]._creation_order))
-
-        table = None
-        if '__table__' not in cls.__dict__:
-            if '__tablename__' in cls.__dict__:
-                tablename = cls.__tablename__
-                autoload = cls.__dict__.get('__autoload__')
-                if autoload:
-                    table_kw = {'autoload': True}
-                else:
-                    table_kw = {}
-                cols = []
-                for key, c in our_stuff.iteritems():
-                    if isinstance(c, ColumnProperty):
-                        for col in c.columns:
-                            if isinstance(col, Column) and col.table is None:
-                                _undefer_column_name(key, col)
-                                cols.append(col)
-                    elif isinstance(c, Column):
-                        _undefer_column_name(key, c)
-                        cols.append(c)
-                cls.__table__ = table = Table(tablename, cls.metadata,
-                                              *cols, **table_kw)
-        else:
-            table = cls.__table__
-
-        mapper_args = getattr(cls, '__mapper_args__', {})
-        if 'inherits' not in mapper_args:
-            inherits = cls.__mro__[1]
-            inherits = cls._decl_class_registry.get(inherits.__name__, None)
-            if inherits:
-                mapper_args['inherits'] = inherits
-                if not mapper_args.get('concrete', False) and table:
-                    # figure out the inherit condition with relaxed rules
-                    # about nonexistent tables, to allow for ForeignKeys to
-                    # not-yet-defined tables (since we know for sure that our
-                    # parent table is defined within the same MetaData)
-                    mapper_args['inherit_condition'] = sql_util.join_condition(
-                        inherits.__table__, table,
-                        ignore_nonexistent_tables=True)
-
-        if hasattr(cls, '__mapper_cls__'):
-            mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
-        else:
-            mapper_cls = mapper
-
-        cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff,
-                                    **mapper_args)
+        
+        _as_declarative(cls, classname, dict_)
         return type.__init__(cls, classname, bases, dict_)
 
     def __setattr__(cls, key, value):
@@ -337,11 +351,11 @@ def comparable_using(comparator_factory):
         return comparable_property(comparator_factory, fn)
     return decorate
 
-def declarative_base(engine=None, metadata=None, mapper=None):
+def declarative_base(engine=None, metadata=None, mapper=None, cls=object):
     lcl_metadata = metadata or MetaData()
     if engine:
         lcl_metadata.bind = engine
-    class Base(object):
+    class Base(cls):
         __metaclass__ = DeclarativeMeta
         metadata = lcl_metadata
         if mapper:
index b7d6199b82527239364e36b2a445afbb3571f66f..608f3f7344917189ba676e98f687af7358da44e1 100644 (file)
@@ -956,7 +956,7 @@ class Query(object):
             if start < 0 or stop < 0:
                 return list(self)[item]
             else:
-                res = self.slice_(start, stop)
+                res = self.slice(start, stop)
                 if step is not None:
                     return list(res)[None:None:item.step]
                 else:
@@ -964,7 +964,7 @@ class Query(object):
         else:
             return list(self[item:item+1])[0]
     
-    def slice_(self, start, stop):
+    def slice(self, start, stop):
         """apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``."""
         
         if start is not None and stop is not None:
@@ -974,7 +974,7 @@ class Query(object):
             self._limit = stop
         elif start is not None and stop is None:
             self._offset = (self._offset or 0) + start
-    slice_ = _generative(__no_statement_condition)(slice_)
+    slice = _generative(__no_statement_condition)(slice)
         
     def limit(self, limit):
         """Apply a ``LIMIT`` to the query and return the newly resulting
index ca91f98fce8862bbc0f52e7a8f503c325235df4d..d5ea6df47748a05150aba7b9f7c286dd146ed046 100644 (file)
@@ -95,6 +95,14 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults):
                 foo = sa.orm.column_property(User.id == 5)
         self.assertRaises(sa.exc.InvalidRequestError, go)
 
+    def test_custom_base(self):
+        class MyBase(object):
+            def foobar(self):
+                return "foobar"
+        Base = decl.declarative_base(cls=MyBase)
+        assert hasattr(Base, 'metadata')
+        assert Base().foobar() == "foobar"
+        
     def test_add_prop(self):
         class User(Base, ComparableEntity):
             __tablename__ = 'users'
@@ -135,7 +143,40 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults):
         eq_(a1, Address(email='two'))
         eq_(a1.user, User(name='u1'))
 
+    def test_as_declarative(self):
+        class User(ComparableEntity):
+            __tablename__ = 'users'
+
+            id = Column('id', Integer, primary_key=True)
+            name = Column('name', String(50))
+            addresses = relation("Address", backref="user")
+
+        class Address(ComparableEntity):
+            __tablename__ = 'addresses'
+
+            id = Column('id', Integer, primary_key=True)
+            email = Column('email', String(50))
+            user_id = Column('user_id', Integer, ForeignKey('users.id'))
+        
+        reg = {}
+        decl.instrument_declarative(User, reg, Base.metadata)
+        decl.instrument_declarative(Address, reg, Base.metadata)
+        Base.metadata.create_all()
+        
+        u1 = User(name='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+        ])
+        sess = create_session()
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
 
+        eq_(sess.query(User).all(), [User(name='u1', addresses=[
+            Address(email='one'),
+            Address(email='two'),
+        ])])
+        
     def test_custom_mapper(self):
         class MyExt(sa.orm.MapperExtension):
             def create_instance(self):
index 7a51c3f7e3a839b4c99ccf7905cbf978cab9a02b..eb7a0f3d37205ded65619b5aa911debbf622b838 100644 (file)
@@ -489,7 +489,7 @@ class FromSelfTest(QueryTest):
 
         assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().all()
 
-        assert [User(id=8), User(id=9)] == create_session().query(User).slice_(1,3)._from_self().all()
+        assert [User(id=8), User(id=9)] == create_session().query(User).slice(1,3)._from_self().all()
         assert [User(id=8)] == list(create_session().query(User).filter(User.id.in_([8,9]))._from_self()[0:1])
     
     def test_join(self):
@@ -1123,7 +1123,7 @@ class MixedEntitiesTest(QueryTest):
         q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
         self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
         
-        q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice_(1, 3).values(User.name, Address.email_address)
+        q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address)
         self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
         
         adalias = aliased(Address)