]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reorganize the usage of __mapper_args__ so that it's only
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 May 2012 14:55:28 +0000 (10:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 May 2012 14:55:28 +0000 (10:55 -0400)
called after the __prepare__() step, if any, so that everything to
do with the mapping occurs after the table is reflected.

lib/sqlalchemy/ext/declarative.py
test/ext/test_declarative_reflection.py

index 2e9012384bf6e92819e85da4436ca7ca3b0e7e5c..6ed6d96f43605fda4dbcdbbb64f4ca1a62523874 100755 (executable)
@@ -300,7 +300,7 @@ as below where we assign the ``name`` column to the attribute ``_name``, generat
 a synonym for ``name``::
 
     from sqlalchemy.ext.declarative import synonym_for
-    
+
     class MyClass(Base):
         __table__ = Table('my_table', Base.metadata,
             Column('id', Integer, primary_key=True),
@@ -516,14 +516,14 @@ loader for the mapper after all subclasses have been declared.
 An abstract base can be declared using the :class:`.AbstractConcreteBase` class::
 
     from sqlalchemy.ext.declarative import AbstractConcreteBase
-    
+
     class Employee(AbstractConcreteBase, Base):
         pass
 
 To have a concrete ``employee`` table, use :class:`.ConcreteBase` instead::
 
     from sqlalchemy.ext.declarative import ConcreteBase
-    
+
     class Employee(ConcreteBase, Base):
         __tablename__ = 'employee'
         employee_id = Column(Integer, primary_key=True)
@@ -531,7 +531,7 @@ To have a concrete ``employee`` table, use :class:`.ConcreteBase` instead::
         __mapper_args__ = {
                         'polymorphic_identity':'employee', 
                         'concrete':True}
-    
+
 
 Either ``Employee`` base can be used in the normal fashion::
 
@@ -572,7 +572,7 @@ mappings are declared.   An example of some commonly mixed-in
 idioms is below::
 
     from sqlalchemy.ext.declarative import declared_attr
-    
+
     class MyMixin(object):
 
         @declared_attr
@@ -620,13 +620,13 @@ is achieved using the ``cls`` argument of the :func:`.declarative_base` function
         @declared_attr
         def __tablename__(cls):
             return cls.__name__.lower()
-            
+
         __table_args__ = {'mysql_engine': 'InnoDB'}
 
         id =  Column(Integer, primary_key=True)
 
     from sqlalchemy.ext.declarative import declarative_base
-    
+
     Base = declarative_base(cls=Base)
 
     class MyModel(Base):
@@ -954,17 +954,17 @@ just from the special class::
 
     class SomeAbstractBase(Base):
         __abstract__ = True
-        
+
         def some_helpful_method(self):
             ""
-            
+
         @declared_attr
         def __mapper_args__(cls):
             return {"helpful mapper arguments":True}
 
     class MyMappedClass(SomeAbstractBase):
         ""
-        
+
 One possible use of ``__abstract__`` is to use a distinct :class:`.MetaData` for different
 bases::
 
@@ -979,7 +979,7 @@ bases::
         metadata = MetaData()
 
 Above, classes which inherit from ``DefaultBase`` will use one :class:`.MetaData` as the 
-registry of tables, and those which inherit from ``OtherBase`` will use a different one.  
+registry of tables, and those which inherit from ``OtherBase`` will use a different one.
 The tables themselves can then be created perhaps within distinct databases::
 
     DefaultBase.metadata.create_all(some_engine)
@@ -1069,7 +1069,7 @@ def _as_declarative(cls, classname, dict_):
     column_copies = {}
     potential_columns = {}
 
-    mapper_args = {}
+    mapper_args_fn = None
     table_args = inherited_table_args = None
     tablename = None
     parent_columns = ()
@@ -1095,11 +1095,14 @@ def _as_declarative(cls, classname, dict_):
 
         for name,obj in vars(base).items():
             if name == '__mapper_args__':
-                if not mapper_args and (
+                if not mapper_args_fn and (
                                         not class_mapped or 
                                         isinstance(obj, declarative_props)
                                     ):
-                    mapper_args = cls.__mapper_args__
+                    # don't even invoke __mapper_args__ until
+                    # after we've determined everything about the
+                    # mapped table.
+                    mapper_args_fn = lambda: cls.__mapper_args__
             elif name == '__tablename__':
                 if not tablename and (
                                         not class_mapped or 
@@ -1164,13 +1167,6 @@ def _as_declarative(cls, classname, dict_):
     if inherited_table_args and not tablename:
         table_args = None
 
-    # make sure that column copies are used rather 
-    # than the original columns from any mixins
-    for k in ('version_id_col', 'polymorphic_on',):
-        if k in mapper_args:
-            v = mapper_args[k]
-            mapper_args[k] = column_copies.get(v,v)
-
     if classname in cls._decl_class_registry:
         util.warn("The classname %r is already in the registry of this"
                   " declarative base, mapped to %r" % (
@@ -1181,6 +1177,11 @@ def _as_declarative(cls, classname, dict_):
     our_stuff = util.OrderedDict()
 
     for k in dict_:
+
+        # TODO: improve this ?  all dunders ?
+        if k in ('__table__', '__tablename__', '__mapper_args__'):
+            continue
+
         value = dict_[k]
         if isinstance(value, declarative_props):
             value = getattr(cls, k)
@@ -1206,17 +1207,17 @@ def _as_declarative(cls, classname, dict_):
     our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)
 
     # extract columns from the class dict
-    cols = set()
+    declared_columns = set()
     for key, c in our_stuff.iteritems():
         if isinstance(c, (ColumnProperty, CompositeProperty)):
             for col in c.columns:
                 if isinstance(col, Column) and \
                     col.table is None:
                     _undefer_column_name(key, col)
-                    cols.add(col)
+                    declared_columns.add(col)
         elif isinstance(c, Column):
             _undefer_column_name(key, c)
-            cols.add(c)
+            declared_columns.add(c)
             # if the column is the same name as the key, 
             # remove it from the explicit properties dict.
             # the normal rules for assigning column-based properties
@@ -1224,7 +1225,7 @@ def _as_declarative(cls, classname, dict_):
             # in multi-column ColumnProperties.
             if key == c.key:
                 del our_stuff[key]
-    cols = sorted(cols, key=lambda c:c._creation_order)
+    declared_columns = sorted(declared_columns, key=lambda c:c._creation_order)
     table = None
 
     if hasattr(cls, '__table_cls__'):
@@ -1250,38 +1251,38 @@ def _as_declarative(cls, classname, dict_):
                 table_kw['autoload'] = True
 
             cls.__table__ = table = table_cls(tablename, cls.metadata,
-                                          *(tuple(cols) + tuple(args)),
+                                          *(tuple(declared_columns) + tuple(args)),
                                            **table_kw)
     else:
         table = cls.__table__
-        if cols:
-            for c in cols:
+        if declared_columns:
+            for c in declared_columns:
                 if not table.c.contains_column(c):
                     raise exc.ArgumentError(
                         "Can't add additional column %r when "
                         "specifying __table__" % c.key
                     )
 
-    if 'inherits' not in mapper_args:
-        for c in cls.__bases__:
-            if _declared_mapping_info(c) is not None:
-                mapper_args['inherits'] = c
-                break
-
     if hasattr(cls, '__mapper_cls__'):
         mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
     else:
         mapper_cls = mapper
 
-    if table is None and 'inherits' not in mapper_args:
+    for c in cls.__bases__:
+        if _declared_mapping_info(c) is not None:
+            inherits = c
+            break
+    else:
+        inherits = None
+
+    if table is None and inherits is None:
         raise exc.InvalidRequestError(
             "Class %r does not have a __table__ or __tablename__ "
             "specified and does not inherit from an existing "
             "table-mapped class." % cls
             )
-
-    elif 'inherits' in mapper_args and not mapper_args.get('concrete', False):
-        inherited_mapper = _declared_mapping_info(mapper_args['inherits'])
+    elif inherits:
+        inherited_mapper = _declared_mapping_info(inherits)
         inherited_table = inherited_mapper.local_table
 
         if table is None:
@@ -1294,7 +1295,7 @@ def _as_declarative(cls, classname, dict_):
                     )
 
             # add any columns declared here to the inherited table.
-            for c in cols:
+            for c in declared_columns:
                 if c.primary_key:
                     raise exc.ArgumentError(
                         "Can't place primary key columns on an inherited "
@@ -1308,67 +1309,103 @@ def _as_declarative(cls, classname, dict_):
                     )
                 inherited_table.append_column(c)
 
-        # single or joined inheritance
-        # exclude any cols on the inherited table which are not mapped on the
-        # parent class, to avoid
-        # mapping columns specific to sibling/nephew classes
-        inherited_mapper = _declared_mapping_info(mapper_args['inherits'])
-        inherited_table = inherited_mapper.local_table
-
-        if 'exclude_properties' not in mapper_args:
-            mapper_args['exclude_properties'] = exclude_properties = \
-                set([c.key for c in inherited_table.c
-                     if c not in inherited_mapper._columntoproperty])
-            exclude_properties.difference_update([c.key for c in cols])
-
-        # look through columns in the current mapper that 
-        # are keyed to a propname different than the colname
-        # (if names were the same, we'd have popped it out above,
-        # in which case the mapper makes this combination).
-        # See if the superclass has a similar column property.
-        # If so, join them together.
-        for k, col in our_stuff.items():
-            if not isinstance(col, expression.ColumnElement):
-                continue
-            if k in inherited_mapper._props:
-                p = inherited_mapper._props[k]
-                if isinstance(p, ColumnProperty):
-                    # note here we place the superclass column
-                    # first.  this corresponds to the 
-                    # append() in mapper._configure_property().
-                    # change this ordering when we do [ticket:1892]
-                    our_stuff[k] = p.columns + [col]
-
     mt = _MapperConfig(mapper_cls, 
-                       cls, table, 
-                       properties=our_stuff, 
-                       **mapper_args)
+                       cls, table,
+                       inherits,
+                       declared_columns, 
+                       column_copies,
+                       our_stuff, 
+                       mapper_args_fn)
     if not hasattr(cls, '__prepare__'):
         mt.map()
 
 class _MapperConfig(object):
     configs = util.OrderedDict()
 
-    def __init__(self, mapper_cls, cls, table, **mapper_args):
+    def __init__(self, mapper_cls, 
+                        cls, 
+                        table, 
+                        inherits,
+                        declared_columns,
+                        column_copies,
+                        properties, mapper_args_fn):
         self.mapper_cls = mapper_cls
         self.cls = cls
         self.local_table = table
-        self.mapper_args = mapper_args
-        self._columntoproperty = set()
-        if table is not None:
-            self._columntoproperty.update(table.c)
+        self.inherits = inherits
+        self.properties = properties
+        self.mapper_args_fn = mapper_args_fn
+        self.declared_columns = declared_columns
+        self.column_copies = column_copies
         self.configs[cls] = self
 
-    @property
-    def args(self):
-        return self.cls, self.local_table, self.mapper_args
+    def _prepare_mapper_arguments(self):
+        cls = self.cls
+        table = self.local_table
+        properties = self.properties
+        if self.mapper_args_fn:
+            mapper_args = self.mapper_args_fn()
+        else:
+            mapper_args = {}
+
+        # make sure that column copies are used rather 
+        # than the original columns from any mixins
+        for k in ('version_id_col', 'polymorphic_on',):
+            if k in mapper_args:
+                v = mapper_args[k]
+                mapper_args[k] = self.column_copies.get(v,v)
+
+        assert 'inherits' not in mapper_args, \
+            "Can't specify 'inherits' explicitly with declarative mappings"
+
+        if self.inherits:
+            mapper_args['inherits'] = self.inherits
+
+        if self.inherits and not mapper_args.get('concrete', False):
+            # single or joined inheritance
+            # exclude any cols on the inherited table which are 
+            # not mapped on the parent class, to avoid
+            # mapping columns specific to sibling/nephew classes
+            inherited_mapper = _declared_mapping_info(self.inherits)
+            inherited_table = inherited_mapper.local_table
+
+            if 'exclude_properties' not in mapper_args:
+                mapper_args['exclude_properties'] = exclude_properties = \
+                    set([c.key for c in inherited_table.c
+                         if c not in inherited_mapper._columntoproperty])
+                exclude_properties.difference_update(
+                        [c.key for c in self.declared_columns])
+
+            # look through columns in the current mapper that 
+            # are keyed to a propname different than the colname
+            # (if names were the same, we'd have popped it out above,
+            # in which case the mapper makes this combination).
+            # See if the superclass has a similar column property.
+            # If so, join them together.
+            for k, col in properties.items():
+                if not isinstance(col, expression.ColumnElement):
+                    continue
+                if k in inherited_mapper._props:
+                    p = inherited_mapper._props[k]
+                    if isinstance(p, ColumnProperty):
+                        # note here we place the superclass column
+                        # first.  this corresponds to the 
+                        # append() in mapper._configure_property().
+                        # change this ordering when we do [ticket:1892]
+                        properties[k] = p.columns + [col]
+
+        result_mapper_args = mapper_args.copy()
+        result_mapper_args['properties'] = properties
+        return result_mapper_args
 
     def map(self):
         self.configs.pop(self.cls, None)
+        mapper_args = self._prepare_mapper_arguments()
+
         self.cls.__mapper__ = self.mapper_cls(
             self.cls,
             self.local_table,
-            **self.mapper_args
+            **mapper_args
         )
 
 class DeclarativeMeta(type):
@@ -1535,7 +1572,7 @@ class declared_attr(property):
     a mapped property or special declarative member name.
 
     .. note:: 
-    
+
        @declared_attr is available as 
        ``sqlalchemy.util.classproperty`` for SQLAlchemy versions
        0.6.2, 0.6.3, 0.6.4.
@@ -1660,7 +1697,7 @@ def declarative_base(bind=None, metadata=None, mapper=None, cls=object,
       and others.  Allows two or more declarative base classes
       to share the same registry of class names for simplified 
       inter-base relationships.
-      
+
     :param metaclass:
       Defaults to :class:`.DeclarativeMeta`.  A metaclass or __metaclass__
       compatible callable to use as the meta type of the generated
@@ -1693,7 +1730,7 @@ def _undefer_column_name(key, column):
 
 class ConcreteBase(object):
     """A helper class for 'concrete' declarative mappings.
-    
+
     :class:`.ConcreteBase` will use the :func:`.polymorphic_union`
     function automatically, against all tables mapped as a subclass
     to this class.   The function is called via the
@@ -1703,7 +1740,7 @@ class ConcreteBase(object):
     :class:`.ConcreteBase` produces a mapped
     table for the class itself.  Compare to :class:`.AbstractConcreteBase`,
     which does not.
-    
+
     Example::
 
         from sqlalchemy.ext.declarative import ConcreteBase
@@ -1747,17 +1784,17 @@ class ConcreteBase(object):
 
 class AbstractConcreteBase(ConcreteBase):
     """A helper class for 'concrete' declarative mappings.
-    
+
     :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
     function automatically, against all tables mapped as a subclass
     to this class.   The function is called via the
     ``__declare_last__()`` function, which is essentially
     a hook for the :func:`.MapperEvents.after_configured` event.
-    
+
     :class:`.AbstractConcreteBase` does not produce a mapped
     table for the class itself.  Compare to :class:`.ConcreteBase`,
     which does.
-    
+
     Example::
 
         from sqlalchemy.ext.declarative import ConcreteBase
@@ -1807,7 +1844,7 @@ class AbstractConcreteBase(ConcreteBase):
 class DeferredReflection(object):
     """A helper class for construction of mappings based on 
     a deferred reflection step.
-    
+
     Normally, declarative can be used with reflection by
     setting a :class:`.Table` object using autoload=True
     as the ``__table__`` attribute on a declarative class.
@@ -1816,22 +1853,22 @@ class DeferredReflection(object):
     at the point at which a normal declarative mapping is
     constructed, meaning the :class:`.Engine` must be available
     at class declaration time.
-    
+
     The :class:`.DeferredReflection` mixin moves the construction
     of mappers to be at a later point, after a specific
     method is called which first reflects all :class:`.Table`
     objects created so far.   Classes can define it as such::
-    
+
         from sqlalchemy.ext.declarative import declarative_base, DeferredReflection
         Base = declarative_base()
-        
+
         class MyClass(DeferredReflection, Base):
             __tablename__ = 'mytable'
-        
+
     Above, ``MyClass`` is not yet mapped.   After a series of
     classes have been defined in the above fashion, all tables
     can be reflected and mappings created using :meth:`.DeferredReflection.prepare`::
-    
+
         engine = create_engine("someengine://...")
         DeferredReflection.prepare(engine)
 
@@ -1843,32 +1880,32 @@ class DeferredReflection(object):
     that use more than one engine.  For example, if an application
     has two engines, you might use two bases, and prepare each
     separately, e.g.::
-    
+
         class ReflectedOne(DeferredReflection, Base):
             __abstract__ = True
 
         class ReflectedTwo(DeferredReflection, Base):
             __abstract__ = True
-        
+
         class MyClass(ReflectedOne):
             __tablename__ = 'mytable'
-        
+
         class MyOtherClass(ReflectedOne):
             __tablename__ = 'myothertable'
 
         class YetAnotherClass(ReflectedTwo):
             __tablename__ = 'yetanothertable'
-        
+
         # ... etc.
-        
+
     Above, the class hierarchies for ``ReflectedOne`` and
     ``ReflectedTwo`` can be configured separately::
-    
+
         ReflectedOne.prepare(engine_one)
         ReflectedTwo.prepare(engine_two)
-    
+
     .. versionadded:: 0.8
-    
+
     """
     @classmethod
     def prepare(cls, engine):
@@ -1877,12 +1914,11 @@ class DeferredReflection(object):
         to_map = [m for m in _MapperConfig.configs.values()
                     if issubclass(m.cls, cls)]
         for thingy in to_map:
-            cls.__prepare__(thingy.args, engine)
+            cls.__prepare__(thingy.local_table, engine)
             thingy.map()
 
     @classmethod
-    def __prepare__(cls, mapper_args, engine):
-        cls, local_table, args = mapper_args
+    def __prepare__(cls, local_table, engine):
         # autoload Table, which is already
         # present in the metadata.  This
         # will fill in db-loaded columns
index df95001899b358777c8bf46b3a93d325f40da892..38dd7cec3abf4c7ef595a22e8c9276c254c612e6 100644 (file)
@@ -235,6 +235,37 @@ class DeferredReflectionTest(DeferredReflectBase):
         decl.DeferredReflection.prepare(testing.db)
         self._roundtrip()
 
+    def test_mapper_args_deferred(self):
+        """test that __mapper_args__ is not called until *after* table reflection"""
+
+        class User(decl.DeferredReflection, fixtures.ComparableEntity, 
+                            Base):
+            __tablename__ = 'users'
+
+            @decl.declared_attr
+            def __mapper_args__(cls):
+                return {
+                    "order_by":cls.__table__.c.name
+                }
+
+        decl.DeferredReflection.prepare(testing.db)
+        sess = Session()
+        sess.add_all([
+            User(name='G'),
+            User(name='Q'),
+            User(name='A'),
+            User(name='C'),
+        ])
+        sess.commit()
+        eq_(
+            sess.query(User).all(),
+            [
+                User(name='A'),
+                User(name='C'),
+                User(name='G'),
+                User(name='Q'),
+            ]
+        )
 
 class DeferredInhReflectBase(DeferredReflectBase):
     def _roundtrip(self):