]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- declarative_base() takes optional kwarg "mapper", which
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Mar 2008 14:41:41 +0000 (14:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Mar 2008 14:41:41 +0000 (14:41 +0000)
is any callable/class/method that produces a mapper,
such as declarative_base(mapper=scopedsession.mapper).
This property can also be set on individual declarative
classes using the "__mapper_cls__" property.

CHANGES
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/util.py
test/ext/declarative.py

diff --git a/CHANGES b/CHANGES
index 084c44294b9400d9f1be709ad1d9f1bfba87c866..d81c0a87b597a9fc476b3a85127864e5b1ecc1a4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -109,7 +109,7 @@ CHANGES
       it when reflecting related tables.  This is stickier
       behavior than before which is why it's off by default.
 
-- extensions
+- declarative extension
     - The "synonym" function is now directly usable with
       "declarative".  Pass in the decorated property using the
       "descriptor" keyword argument, e.g.: somekey =
@@ -137,6 +137,12 @@ CHANGES
      - inheritance in declarative can be disabled when sending
        "inherits=None" to __mapper_args__.
 
+     - declarative_base() takes optional kwarg "mapper", which 
+       is any callable/class/method that produces a mapper,
+       such as declarative_base(mapper=scopedsession.mapper).
+       This property can also be set on individual declarative
+       classes using the "__mapper_cls__" property.
+     
 0.4.4
 ------
 - sql
index 62691a906a22ec8bd24374d58cafc5dc48b70599..d8576d79b172a1b0cc8be9cdc2268997465fede2 100644 (file)
@@ -162,6 +162,7 @@ from sqlalchemy.orm import synonym as _orm_synonym, mapper, comparable_property
 from sqlalchemy.orm.interfaces import MapperProperty
 from sqlalchemy.orm.properties import PropertyLoader, ColumnProperty
 from sqlalchemy import util, exceptions
+import types
 
 __all__ = ['declarative_base', 'synonym_for', 'comparable_using',
            'declared_synonym']
@@ -216,8 +217,12 @@ class DeclarativeMeta(type):
             inherits = cls.__mro__[1]
             inherits = cls._decl_class_registry.get(inherits.__name__, None)
             mapper_args['inherits'] = inherits
-            
-        cls.__mapper__ = mapper(cls, table, properties=our_stuff, **mapper_args)
+        
+        if hasattr(cls, '__mapper_cls__'):
+            mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
+        else:
+            mapper_cls = mapper
+        cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
         return type.__init__(cls, classname, bases, dict_)
 
     def __setattr__(cls, key, value):
@@ -294,13 +299,15 @@ def comparable_using(comparator_factory):
         return comparable_property(comparator_factory, fn)
     return decorate
 
-def declarative_base(engine=None, metadata=None):
+def declarative_base(engine=None, metadata=None, mapper=None):
     lcl_metadata = metadata or MetaData()
+    if engine:
+        lcl_metadata.bind = engine
     class Base(object):
         __metaclass__ = DeclarativeMeta
         metadata = lcl_metadata
-        if engine:
-            metadata.bind = engine
+        if mapper:
+            __mapper_cls__ = mapper
         _decl_class_registry = {}
         def __init__(self, **kwargs):
             for k in kwargs:
index 90332fdc0e35254b4006e79566a3cca1a2f686af..8451d28b560c2bc92f98cbac432e85b4248f5d4a 100644 (file)
@@ -279,6 +279,14 @@ def get_func_kwargs(func):
     """Return the full set of legal kwargs for the given `func`."""
     return inspect.getargspec(func)[0]
 
+def unbound_method_to_callable(func_or_cls):
+    """Adjust the incoming callable such that a 'self' argument is not required."""
+    
+    if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self:
+        return func_or_cls.im_func
+    else:
+        return func_or_cls
+
 # from paste.deploy.converters
 def asbool(obj):
     if isinstance(obj, (str, unicode)):
index 5da2dded5ee18ff0928999b9a03adb64c5ee4ff6..c2f49138cc8f57b2dffe190e5348d57af05ae7a8 100644 (file)
@@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests()
 
 from sqlalchemy import *
 from sqlalchemy.orm import *
+from sqlalchemy.orm.interfaces import MapperExtension
 from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \
                                        synonym_for, comparable_using
 from sqlalchemy import exceptions
@@ -135,6 +136,41 @@ class DeclarativeTest(TestBase, AssertsExecutionResults):
         self.assertEquals(a1, Address(email='two'))
         self.assertEquals(a1.user, User(name='u1'))
 
+    
+    def test_custom_mapper(self):
+        class MyExt(MapperExtension):
+            def create_instance(self):
+                return "CHECK"
+
+        def mymapper(cls, tbl, **kwargs):
+            kwargs['extension'] = MyExt()
+            return mapper(cls, tbl, **kwargs)
+
+        from sqlalchemy.orm.mapper import Mapper
+        class MyMapper(Mapper):
+            def __init__(self, *args, **kwargs):
+                kwargs['extension'] = MyExt()
+                Mapper.__init__(self, *args, **kwargs)
+
+        from sqlalchemy.orm import scoping
+        ss = scoping.ScopedSession(create_session)
+        ss.extension = MyExt()
+        ss_mapper = ss.mapper
+
+        for mapperfunc in (mymapper, MyMapper, ss_mapper):
+            base = declarative_base()
+            class Foo(base):
+                __tablename__ = 'foo'
+                __mapper_cls__ = mapperfunc
+                id = Column(Integer, primary_key=True)
+            assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK'
+
+            base = declarative_base(mapper=mapperfunc)
+            class Foo(base):
+                __tablename__ = 'foo'
+                id = Column(Integer, primary_key=True)
+            assert Foo.__mapper__.compile().extension.create_instance() == 'CHECK'
+
 
     @testing.emits_warning('Ignoring declarative-like tuple value of '
                            'attribute id')