]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixes [ticket:185], join object determines primary key and removes
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Mar 2007 22:06:36 +0000 (22:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Mar 2007 22:06:36 +0000 (22:06 +0000)
columns that are FK's to other columns in the primary key collection.
- removed workaround code from query.py get()
- removed obsolete inheritance test from mapper
- added new get() test to inheritance.py for this particular issue
- ColumnCollection has nicer string method

lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/orm/inheritance.py
test/orm/mapper.py

index 070fb4cb5ec12e06d6711d9557928b439262bd41..77499c2714343e7840ac0f167da3c9af3965fe7b 100644 (file)
@@ -774,17 +774,9 @@ class Query(object):
             ident = key[1]
         else:
             ident = util.to_list(ident)
-        i = 0
         params = {}
-        for primary_key in self.primary_key_columns:
+        for i, primary_key in enumerate(self.primary_key_columns):
             params[primary_key._label] = ident[i]
-            # if there are not enough elements in the given identifier, then
-            # use the previous identifier repeatedly.  this is a workaround for the issue
-            # in [ticket:185], where a mapper that uses joined table inheritance needs to specify
-            # all primary keys of the joined relationship, which includes even if the join is joining
-            # two primary key (and therefore synonymous) columns together, the usual case for joined table inheritance.
-            if len(ident) > i + 1:
-                i += 1
         try:
             statement = self.compile(self._get_clause, lockmode=lockmode)
             return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
index bd601ed80013ac16700768a899c109dd62be16d7..5ed95fabb5bf6f81fdcb33b2183d76bd75efbf67 100644 (file)
@@ -680,7 +680,7 @@ class ForeignKey(SchemaItem):
         """Return True if the given table is referenced by this ``ForeignKey``."""
 
         return table.corresponding_column(self.column, False) is not None
-
+    
     def _init_column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
         # be defined without dependencies
index 6924a60ceb66164646d7e0241cde6aa084f62718..74f085cb1a9e669510bca19e5ea56f97824a1afe 100644 (file)
@@ -1075,15 +1075,24 @@ class ColumnCollection(util.OrderedProperties):
         super(ColumnCollection, self).__init__()
         [self.add(c) for c in cols]
 
+    def __str__(self):
+        return repr([str(c) for c in self])
+        
     def add(self, column):
         """Add a column to this collection.
 
         The key attribute of the column will be used as the hash key
         for this dictionary.
         """
-
         self[column.key] = column
-
+    
+    def remove(self, column):
+        del self[column.key]
+        
+    def extend(self, iter):
+        for c in iter:
+            self.add(c)
+            
     def __eq__(self, other):
         l = []
         for c in other:
@@ -1243,6 +1252,16 @@ class FromClause(Selectable):
         self._primary_key = ColumnCollection()
         self._foreign_keys = util.Set()
         self._orig_cols = {}
+        for co in self._adjusted_exportable_columns():
+            cp = self._proxy_column(co)
+            for ci in cp.orig_set:
+                self._orig_cols[ci] = cp
+        if self.oid_column is not None:
+            for ci in self.oid_column.orig_set:
+                self._orig_cols[ci] = self.oid_column
+    
+    def _adjusted_exportable_columns(self):
+        """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
         export = self._exportable_columns()
         for column in export:
             try:
@@ -1250,13 +1269,8 @@ class FromClause(Selectable):
             except AttributeError:
                 continue
             for co in s.columns:
-                cp = self._proxy_column(co)
-                for ci in cp.orig_set:
-                    self._orig_cols[ci] = cp
-        if self.oid_column is not None:
-            for ci in self.oid_column.orig_set:
-                self._orig_cols[ci] = self.oid_column
-
+                yield co
+        
     def _exportable_columns(self):
         return []
 
@@ -1661,10 +1675,23 @@ class Join(FromClause):
         else:
             self.onclause = onclause
         self.isouter = isouter
-
+        self.__folded_equivalents = None
+        self._init_primary_key()
+        
     name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name)
     encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
-    
+
+    def _init_primary_key(self):
+        pkcol = util.Set()
+        for col in self._adjusted_exportable_columns():
+            if col.primary_key:
+                pkcol.add(col)
+        for col in list(pkcol):
+            for f in col.foreign_keys:
+                if f.column in pkcol:
+                    pkcol.remove(col)
+        self.primary_key.extend(pkcol)
+        
     def _locate_oid_column(self):
         return self.left.oid_column
 
@@ -1673,8 +1700,6 @@ class Join(FromClause):
 
     def _proxy_column(self, column):
         self._columns[column._label] = column
-        if column.primary_key:
-            self._primary_key.add(column)
         for f in column.foreign_keys:
             self._foreign_keys.add(f)
         return column
@@ -1706,6 +1731,8 @@ class Join(FromClause):
         return True
 
     def _get_folded_equivalents(self, equivs=None):
+        if self.__folded_equivalents is not None:
+            return self.__folded_equivalents
         if equivs is None:
             equivs = util.Set()
         class LocateEquivs(NoColumnVisitor):
@@ -1731,7 +1758,8 @@ class Join(FromClause):
                     used.add(c.name)
             else: 
                 collist.append(c)
-        return collist
+        self.__folded_equivalents = collist
+        return self.__folded_equivalents
         
     def select(self, whereclause = None, fold_equivalents=False, **kwargs):
         """Create a ``Select`` from this ``Join``.
@@ -1740,9 +1768,11 @@ class Join(FromClause):
           the WHERE criterion that will be sent to the ``select()`` function
           
         fold_equivalents
-          based on the join criterion of this ``Join``, do not include equivalent
-          columns in the column list of the resulting select.  this will recursively
-          apply to any joins directly nested by this one as well.
+          based on the join criterion of this ``Join``, do not include repeat
+          column names in the column list of the resulting select, for columns that
+          are calculated to be "equivalent" based on the join criterion of this
+          ``Join``. this will recursively apply to any joins directly nested by
+          this one as well.
           
         \**kwargs
           all other kwargs are sent to the underlying ``select()`` function
index d09e685b3901d251b7efe3fd5a31a8c14b497621..0aabdb9be2b35ee6c0dedebe65e051ad45378a42 100644 (file)
@@ -96,26 +96,41 @@ class InheritTest2(testbase.ORMTest):
             Column('foo_id', Integer, ForeignKey('foo.id')),
             Column('bar_id', Integer, ForeignKey('bar.bid')))
 
+    def testget(self):
+        class Foo(object):pass
+        def __init__(self, data=None):
+            self.data = data
+        class Bar(Foo):pass
+        
+        mapper(Foo, foo)
+        mapper(Bar, bar, inherits=Foo)
+        
+        b = Bar('somedata')
+        sess = create_session()
+        sess.save(b)
+        sess.flush()
+        sess.clear()
+        
+        # test that "bar.bid" does not need to be referenced in a get
+        # (ticket 185)
+        assert sess.query(Bar).get(b.id).id == b.id
+        
     def testbasic(self):
         class Foo(object): 
             def __init__(self, data=None):
                 self.data = data
-            def __str__(self):
-                return "Foo(%s)" % self.data
-            def __repr__(self):
-                return str(self)
 
         mapper(Foo, foo)
         class Bar(Foo):
-            def __str__(self):
-                return "Bar(%s)" % self.data
+            pass
 
         mapper(Bar, bar, inherits=Foo, properties={
             'foos': relation(Foo, secondary=foo_bar, lazy=False)
         })
         
         sess = create_session()
-        b = Bar('barfoo', _sa_session=sess)
+        b = Bar('barfoo')
+        sess.save(b)
         sess.flush()
 
         f1 = Foo('subfoo1')
index f5a4613c95c312324886d6de4575bdab8c9b9b54..942fdfdd9e0e1fd11d541e1c4fb1e235260af495 100644 (file)
@@ -626,84 +626,6 @@ class MapperTest(MapperSuperTest):
         q3 = sess.query(User).options(eagerload('orders.items.keywords'))
         u = q3.select()
         self.assert_sql_count(db, go, 2)
-        
-class InheritanceTest(MapperSuperTest):
-
-    def testinherits(self):
-        class _Order(object):
-            pass
-        ordermapper = mapper(_Order, orders)
-            
-        class _User(object):
-            pass
-        usermapper = mapper(_User, users, properties = dict(
-                orders = relation(ordermapper, lazy = False)
-            ))
-
-        class AddressUser(_User):
-            pass
-        mapper(AddressUser, addresses, inherits = usermapper)
-        
-        sess = create_session()
-        q = sess.query(AddressUser)    
-        l = q.select()
-        
-        jack = l[0]
-        self.assert_(jack.user_name=='jack')
-        jack.email_address = 'jack@gmail.com'
-        sess.flush()
-        sess.clear()
-        au = q.get_by(user_name='jack')
-        self.assert_(au.email_address == 'jack@gmail.com')
-
-    def testinherits2(self):
-        class _Order(object):
-            pass
-        class _Address(object):
-            pass
-        class AddressUser(_Address):
-            pass
-        ordermapper = mapper(_Order, orders)
-        addressmapper = mapper(_Address, addresses)
-        usermapper = mapper(AddressUser, users, inherits = addressmapper,
-            properties = {
-                'orders' : relation(ordermapper, lazy=False)
-            })
-        sess = create_session()
-        l = sess.query(usermapper).select()
-        jack = l[0]
-        self.assert_(jack.user_name=='jack')
-        jack.email_address = 'jack@gmail.com'
-        sess.flush()
-        sess.clear()
-        au = sess.query(usermapper).get_by(user_name='jack')
-        self.assert_(au.email_address == 'jack@gmail.com')
-
-    def testlazyoption(self):
-        """test that a lazy options gets created against its correct mapper when
-        using options with inheriting mappers"""
-        class _Order(object):
-            pass
-        class _User(object):
-            pass
-        class AddressUser(_User):
-            pass
-        ordermapper = mapper(_Order, orders)
-        usermapper = mapper(_User, users, 
-            properties = {
-                'orders' : relation(ordermapper, lazy=True)
-            })
-        amapper = mapper(AddressUser, addresses, inherits = usermapper)
-            
-        sess = create_session()
-
-        def go():
-            l = sess.query(AddressUser).options(lazyload('orders')).select()
-            # this would fail because the "orders" lazyloader gets created against AddressUsers selectable
-            # and not _User's.
-            assert len(l[0].orders) == 3
-        self.assert_sql_count(db, go, 2)
-        
             
     
 class DeferredTest(MapperSuperTest):