]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
further refinement to the inheritance "descriptor" detection such that
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Aug 2008 17:13:31 +0000 (17:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Aug 2008 17:13:31 +0000 (17:13 +0000)
local columns will still override superclass descriptors.

lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/basic.py

index d7b7fbb7c0255479c020c0bb1e470fa6f45d7bfe..a0053fe267a8c911cf4a56c35b76219cc3960415 100644 (file)
@@ -639,18 +639,25 @@ class Mapper(object):
 
             return getattr(getattr(cls, clskey), key)
 
-    def _should_exclude(self, name):
+    def _should_exclude(self, name, local):
         """determine whether a particular property should be implicitly present on the class.
         
         This occurs when properties are propagated from an inherited class, or are 
         applied from the columns present in the mapped table.
         
         """
-        # check for an existing descriptor
-        if getattr(self.class_, name, None) \
-            and hasattr(getattr(self.class_, name), '__get__'):
-            return True
         
+        # check for descriptors, either local or from
+        # an inherited class
+        if local:
+            if self.class_.__dict__.get(name, None)\
+                and hasattr(self.class_.__dict__[name], '__get__'):
+                return True
+        else:
+            if getattr(self.class_, name, None)\
+                and hasattr(getattr(self.class_, name), '__get__'):
+                return True
+
         if (self.include_properties is not None and
             name not in self.include_properties):
             self.__log("not including property %s" % (name))
@@ -681,7 +688,7 @@ class Mapper(object):
         # pull properties from the inherited mapper if any.
         if self.inherits:
             for key, prop in self.inherits.__props.iteritems():
-                if key not in self.__props and not self._should_exclude(key):
+                if key not in self.__props and not self._should_exclude(key, local=False):
                     self._adapt_inherited_property(key, prop)
 
         # create properties for each column in the mapped table,
@@ -690,9 +697,6 @@ class Mapper(object):
             if column in self._columntoproperty:
                 continue
                 
-            if self._should_exclude(column.key):
-                continue
-
             column_key = (self.column_prefix or '') + column.key
 
             # adjust the "key" used for this column to that
@@ -700,14 +704,15 @@ class Mapper(object):
             for mapper in self.iterate_to_root():
                 if column in mapper._columntoproperty:
                     column_key = mapper._columntoproperty[column].key
-                
-            self._compile_property(column_key, column, init=False, setparent=True)
+            
+            if not self._should_exclude(column_key, local=self.local_table.c.contains_column(column)):
+                self._compile_property(column_key, column, init=False, setparent=True)
 
         # do a special check for the "discriminiator" column, as it may only be present
         # in the 'with_polymorphic' selectable but we need it for the base mapper
         if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
             col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
-            if self._should_exclude(col.key):
+            if self._should_exclude(col.key, local=False):
                 raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key)
             self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
 
index decc867951679501384ce51a73ba291afcde8a04..20b4600b1cd7969f3e2a90689bea91154c9f4a8d 100644 (file)
@@ -873,39 +873,57 @@ class OverrideColKeyTest(ORMTest):
         sess.flush()
         assert sess.query(Sub).one().data == "im the data"
     
-    def test_two_levels(self):
+    def test_sub_columns_over_base_descriptors(self):
         class Base(object):
-            pass
+            @property
+            def subdata(self):
+                return "this is base"
 
         class Sub(Base):
-            @property
-            def data(self):
-                return "im sub"
+            pass
 
-        class SubSub(Sub):
-            @property
-            def data(self):
-                return "im sub sub"
-        
         mapper(Base, base)
         mapper(Sub, subtable, inherits=Base)
-        mapper(SubSub, inherits=Sub)
         
         sess = create_session()
-        s1 = Sub()
-        assert s1.data == "im sub"
-        s2 = SubSub()
-        assert s2.data == "im sub sub"
         b1 = Base()
-        b1.data="this is some data"
-        assert b1.data == "this is some data"
-        
-        sess.add_all([s1, s2, b1])
+        assert b1.subdata == "this is base"
+        s1 = Sub()
+        s1.subdata = "this is sub"
+        assert s1.subdata == "this is sub"
+
+        sess.add_all([s1, b1])
         sess.flush()
         sess.clear()
         
-        assert sess.query(Sub).get(s1.base_id).data == "im sub"
-        assert sess.query(SubSub).get(s2.base_id).data == "im sub sub"
+        assert sess.query(Base).get(b1.base_id).subdata == "this is base"
+        assert sess.query(Sub).get(s1.base_id).subdata == "this is sub"
+
+    def test_base_descriptors_over_base_cols(self):
+        class Base(object):
+            @property
+            def data(self):
+                return "this is base"
+
+        class Sub(Base):
+            pass
+
+        mapper(Base, base)
+        mapper(Sub, subtable, inherits=Base)
+
+        sess = create_session()
+        b1 = Base()
+        assert b1.data == "this is base"
+        s1 = Sub()
+        assert s1.data == "this is base"
+
+        sess.add_all([s1, b1])
+        sess.flush()
+        sess.clear()
+
+        assert sess.query(Base).get(b1.base_id).data == "this is base"
+        assert sess.query(Sub).get(s1.base_id).data == "this is base"
+
         
 class DeleteOrphanTest(ORMTest):
     def define_tables(self, metadata):