]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixes for some nasty edge cases when usng descriptors to compute special attributes
authorChris Withers <chris@simplistix.co.uk>
Tue, 2 Mar 2010 10:17:31 +0000 (10:17 +0000)
committerChris Withers <chris@simplistix.co.uk>
Tue, 2 Mar 2010 10:17:31 +0000 (10:17 +0000)
lib/sqlalchemy/ext/declarative.py
test/ext/test_declarative.py

index 12b79c79651669215d772ac6e9133a1714f6e7e0..fe45e6c175872a80d5e9220caa80aac4ca28eafc 100644 (file)
@@ -528,9 +528,10 @@ def instrument_declarative(cls, registry, metadata):
     
 def _as_declarative(cls, classname, dict_):
 
-    # this spelling enables these attributes to be descriptors
-    mapper_args = '__mapper_args__' in dict_ and cls.__mapper_args__ or {}
-    table_args = '__table_args__' in dict_ and cls.__table_args__ or None
+    # doing it this way enables these attributes to be descriptors,
+    # see below...
+    get_mapper_args = '__mapper_args__' in dict_
+    get_table_args = '__table_args__' in dict_
     
     # dict_ will be a dictproxy, which we can't write to, and we need to!
     dict_ = dict(dict_)
@@ -544,12 +545,20 @@ def _as_declarative(cls, classname, dict_):
                 obj = getattr(base,name)
                 if isinstance(obj, Column):
                     dict_[name]=column_copies[obj]=obj.copy()
-            mapper_args = mapper_args or getattr(base,'__mapper_args__',mapper_args)
-            table_args = table_args or getattr(base,'__table_args__',None)
+            get_mapper_args = get_mapper_args or getattr(base,'__mapper_args__',None)
+            get_table_args = get_table_args or getattr(base,'__table_args__',None)
             tablename = getattr(base,'__tablename__',None)
             if tablename:
+                # subtle: if tablename is a descriptor here, we actually
+                # put the wrong value in, but it serves as a marker to get
+                # the right value value...
                 dict_['__tablename__']=tablename
 
+    # now that we know whether or not to get these, get them from the class
+    # if we should, enabling them to be decorators
+    mapper_args = get_mapper_args and cls.__mapper_args__ or {}
+    table_args = get_table_args and cls.__table_args__ or None
+    
     # make sure that column copies are used rather than the original columns
     # from any mixins
     for k, v in mapper_args.iteritems():
@@ -595,6 +604,8 @@ def _as_declarative(cls, classname, dict_):
     table = None
     if '__table__' not in dict_:
         if '__tablename__' in dict_:
+            # see above: if __tablename__ is a descriptor, this
+            # means we get the right value used!
             tablename = cls.__tablename__
             
             if isinstance(table_args, dict):
index a22a9581ea436bbaa86cb322a9d48fb4523096b0..3d11c0b4dcbe89f93b1678ddfde5cf861459e34c 100644 (file)
@@ -1900,6 +1900,19 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         eq_(MyModel.__table__.kwargs,{'mysql_engine': 'InnoDB'})
     
+    def test_table_args_inherited_descriptor(self):
+        
+        class MyMixin:
+            @classproperty
+            def __table_args__(cls):
+                return {'info':cls.__name__} 
+
+        class MyModel(Base,MyMixin):
+            __tablename__='test'
+            id =  Column(Integer, primary_key=True)
+
+        eq_(MyModel.__table__.info,'MyModel')
+    
     def test_table_args_inherited_single_table_inheritance(self):
         
         class MyMixin:
@@ -1967,6 +1980,23 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         eq_(MyModel.__mapper__.always_refresh,True)
     
     
+    def test_mapper_args_inherited_descriptor(self):
+        
+        class MyMixin:
+            @classproperty
+            def __mapper_args__(cls):
+                # tenuous, but illustrates the problem!
+                if cls.__name__=='MyModel':
+                    return dict(always_refresh=True)
+                else:
+                    return dict(always_refresh=False)
+
+        class MyModel(Base,MyMixin):
+            __tablename__='test'
+            id =  Column(Integer, primary_key=True)
+
+        eq_(MyModel.__mapper__.always_refresh,True)
+    
     def test_mapper_args_polymorphic_on_inherited(self):
 
         class MyMixin: