]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merge of trunk r6544
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Dec 2009 03:10:59 +0000 (03:10 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Dec 2009 03:10:59 +0000 (03:10 +0000)
- Session.execute() now locates table- and
mapper-specific binds based on a passed
in expression which is an insert()/update()/delete()
construct. [ticket:1054]

CHANGES
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/util.py
test/orm/test_session.py

diff --git a/CHANGES b/CHANGES
index b97828ff05c0607a394a8603e824b46ec71beed9..525a9112f26f6aaa65da1acdea06fff99a4ce8de 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -22,7 +22,12 @@ CHANGES
       various unserializable options like those generated
       by contains_eager() out of individual instance states.
       [ticket:1553]
-
+    
+    - Session.execute() now locates table- and 
+      mapper-specific binds based on a passed
+      in expression which is an insert()/update()/delete() 
+      construct. [ticket:1054]
+      
     - Fixed a needless select which would occur when merging
       transient objects that contained a null primary key
       identifier.  [ticket:1618]
index 3dea2f6a3ddcf8b51a44d36ddcf54e71502dd53a..be0821e3b848333b9d186f59339719fc4a48e44c 100644 (file)
@@ -570,13 +570,11 @@ class Session(object):
         self._mapper_flush_opts = {}
 
         if binds is not None:
-            for mapperortable, value in binds.iteritems():
-                if isinstance(mapperortable, type):
-                    mapperortable = _class_mapper(mapperortable).base_mapper
-                self.__binds[mapperortable] = value
-                if isinstance(mapperortable, Mapper):
-                    for t in mapperortable._all_tables:
-                        self.__binds[t] = value
+            for mapperortable, bind in binds.iteritems():
+                if isinstance(mapperortable, (type, Mapper)):
+                    self.bind_mapper(mapperortable, bind)
+                else:
+                    self.bind_table(mapperortable, bind)
 
         if not self.autocommit:
             self.begin()
@@ -857,7 +855,7 @@ class Session(object):
                     "a binding.")
 
         c_mapper = mapper is not None and _class_to_mapper(mapper) or None
-
+        
         # manually bound?
         if self.__binds:
             if c_mapper:
@@ -866,7 +864,7 @@ class Session(object):
                 elif c_mapper.mapped_table in self.__binds:
                     return self.__binds[c_mapper.mapped_table]
             if clause:
-                for t in sql_util.find_tables(clause):
+                for t in sql_util.find_tables(clause, include_crud=True):
                     if t in self.__binds:
                         return self.__binds[t]
 
index 9be405e2192b47c0abb44e0c09b5eada3773b201..dccd3d4627f325c2fdc22a56ce419c67296d4809 100644 (file)
@@ -47,7 +47,9 @@ def find_join_source(clauses, join_to):
         return None, None
 
     
-def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
+def find_tables(clause, check_columns=False, 
+                include_aliases=False, include_joins=False, 
+                include_selects=False, include_crud=False):
     """locate Table objects within the given expression."""
     
     tables = []
@@ -61,7 +63,11 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join
         
     if include_aliases:
         _visitors['alias']  = tables.append
-
+    
+    if include_crud:
+        _visitors['insert'] = _visitors['update'] = \
+                    _visitors['delete'] = lambda ent: tables.append(ent.table)
+        
     if check_columns:
         def visit_column(column):
             tables.append(column.table)
index 89923081a359ceedf12ff697e39cf2b309a2ba96..828dd1316fca65d2e2c090d342c22cb8d5b9f50a 100644 (file)
@@ -105,49 +105,74 @@ class SessionTest(_fixtures.FixtureTest):
 
     @engines.close_open_connections
     @testing.resolve_artifact_names
-    def test_table_binds_from_expression(self):
-        """Session can extract Table objects from ClauseElements and match them to tables."""
+    def test_mapped_binds(self):
 
-        mapper(Address, addresses)
-        mapper(User, users, properties={
+        # ensure tables are unbound
+        m2 = sa.MetaData()
+        users_unbound =users.tometadata(m2)
+        addresses_unbound = addresses.tometadata(m2)
+
+        mapper(Address, addresses_unbound)
+        mapper(User, users_unbound, properties={
             'addresses':relation(Address,
                                  backref=backref("user", cascade="all"),
                                  cascade="all")})
 
-        Session = sessionmaker(binds={users: self.metadata.bind,
-                                      addresses: self.metadata.bind})
+        Session = sessionmaker(binds={User: self.metadata.bind,
+                                      Address: self.metadata.bind})
         sess = Session()
 
-        sess.execute(users.insert(), params=dict(id=1, name='ed'))
-        eq_(sess.execute(users.select(users.c.id == 1)).fetchall(),
-            [(1, 'ed')])
+        u1 = User(id=1, name='ed')
+        sess.add(u1)
+        eq_(sess.query(User).filter(User.id==1).all(),
+            [User(id=1, name='ed')])
+
+        # test expression binding
+        sess.execute(users_unbound.insert(), params=dict(id=2, name='jack'))
+        eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(),
+            [(2, 'jack')])
 
-        eq_(sess.execute(users.select(User.id == 1)).fetchall(),
-            [(1, 'ed')])
+        eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(),
+            [(2, 'jack')])
 
+        sess.execute(users_unbound.delete())
+        eq_(sess.execute(users_unbound.select()).fetchall(), [])
+        
         sess.close()
 
     @engines.close_open_connections
     @testing.resolve_artifact_names
-    def test_mapped_binds_from_expression(self):
-        """Session can extract Table objects from ClauseElements and match them to tables."""
+    def test_table_binds(self):
 
-        mapper(Address, addresses)
-        mapper(User, users, properties={
+        # ensure tables are unbound
+        m2 = sa.MetaData()
+        users_unbound =users.tometadata(m2)
+        addresses_unbound = addresses.tometadata(m2)
+
+        mapper(Address, addresses_unbound)
+        mapper(User, users_unbound, properties={
             'addresses':relation(Address,
                                  backref=backref("user", cascade="all"),
                                  cascade="all")})
 
-        Session = sessionmaker(binds={User: self.metadata.bind,
-                                      Address: self.metadata.bind})
+        Session = sessionmaker(binds={users_unbound: self.metadata.bind,
+                                      addresses_unbound: self.metadata.bind})
         sess = Session()
 
-        sess.execute(users.insert(), params=dict(id=1, name='ed'))
-        eq_(sess.execute(users.select(users.c.id == 1)).fetchall(),
-            [(1, 'ed')])
+        u1 = User(id=1, name='ed')
+        sess.add(u1)
+        eq_(sess.query(User).filter(User.id==1).all(),
+            [User(id=1, name='ed')])
+
+        sess.execute(users_unbound.insert(), params=dict(id=2, name='jack'))
+        eq_(sess.execute(users_unbound.select(users_unbound.c.id == 2)).fetchall(),
+            [(2, 'jack')])
+
+        eq_(sess.execute(users_unbound.select(User.id == 2)).fetchall(),
+            [(2, 'jack')])
 
-        eq_(sess.execute(users.select(User.id == 1)).fetchall(),
-            [(1, 'ed')])
+        sess.execute(users_unbound.delete())
+        eq_(sess.execute(users_unbound.select()).fetchall(), [])
 
         sess.close()