]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 20:44:41 +0000 (20:44 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 20:44:41 +0000 (20:44 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
test/mapper.py
test/query.py

index 7689c53e82571a8504e2acd48a8c4c94db92bab1..9e662f455a3bf86e3d947b35a9af2ecd6795b9c7 100644 (file)
@@ -132,7 +132,7 @@ class SQLEngine(schema.SchemaEngine):
 
         if echo is True or self._echo:
             self.log(statement)
-            self.log("here are the params: " + repr(parameters))
+            self.log(repr(parameters))
 
         if connection is None:
             poolconn = self.connection()
@@ -162,6 +162,14 @@ class ResultProxy:
                 self.props[i] = i
                 i+=1
 
+    def fetchall(self):
+        l = []
+        while True:
+            v = self.fetchone()
+            if v is None:
+                return l
+            l.append(v)
+            
     def fetchone(self):
         row = self.cursor.fetchone()
         if row is not None:
@@ -174,5 +182,7 @@ class RowProxy:
     def __init__(self, parent, row):
         self.parent = parent
         self.row = row
+    def __repr__(self):
+        return repr(self.row)
     def __getitem__(self, key):
         return self.row[self.parent.props[key]]
index c5d2e8c193d3c693337d75e2bf3962c4e60eed15..283bfb97bcd3c44396af1a6dec687922aeb912a2 100644 (file)
@@ -73,10 +73,17 @@ def lazyload(name):
 class Mapper(object):
     def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None):
         self.class_ = class_
-        self.selectable = selectable
         self.use_smart_properties = use_smart_properties
+
+        self.selectable = selectable
+        tf = Mapper.TableFinder()
+        self.selectable.accept_visitor(tf)
+        self.tables = tf.tables
+
         if table is None:
-            self.table = self._find_table(selectable)
+            if len(self.tables) > 1:
+                raise "Selectable contains multiple tables - specify primary table argument to Mapper"
+            self.table = self.tables[0]
         else:
             self.table = table
 
@@ -141,7 +148,7 @@ class Mapper(object):
         except KeyError:
             clause = sql.and_()
             i = 0
-            for primary_key in self.selectable.primary_keys:
+            for primary_key in self.table.primary_keys:
                 # appending to the and_'s clause list directly to skip
                 # typechecks etc.
                 clause.clauses.append(primary_key == ident[i])
@@ -190,7 +197,7 @@ class Mapper(object):
         else:
             return self._select_whereclause(arg, **params)
         
-    def save(self, object, traverse = True, refetch = False):
+    def save(self, obj, traverse = True, refetch = False):
         """saves the object.  based on the existence of its primary key, either inserts or updates.
         primary key is determined by the underlying database engine's sequence methodology.
         traverse indicates attached objects should be saved as well.
@@ -199,31 +206,44 @@ class Mapper(object):
         of the attribute, determines if the item is saved.  if smart attributes are not being 
         used, the item is saved unconditionally.
         """
-        if getattr(object, 'dirty', True):
-            pass
-            # do the save
+        # TODO: support multi-table saves
+        if getattr(obj, 'dirty', True):
+            for table in self.tables:
+                for col in table.columns:
+                    if getattr(obj, col.key, None) is None:
+                        self.insert(obj, table)
+                        break
+                else:
+                    self.update(obj, table)
+
         for prop in self.props.values():
-            prop.save(object, traverse, refetch)
-    
-    def remove(self, object, traverse = True):
+            prop.save(obj, traverse, refetch)
+
+    def remove(self, obj, traverse = True):
         """removes the object.  traverse indicates attached objects should be removed as well."""
         pass
-    
-    def insert(self, obj):
-        """inserts the object into its table, regardless of primary key being set.  this is a 
+
+    def insert(self, obj, table = None):
+        """inserts an object into one table, regardless of primary key being set.  this is a 
         lower-level operation than save."""
+
+        if table is None:
+            table = self.table
+
         params = {}
-        for col in self.table.columns:
+        for col in table.columns:
             params[col.key] = getattr(obj, col.key, None)
-        ins = self.table.insert()
+        ins = table.insert()
+        ins.echo = self.echo
         ins.execute(**params)
 
-        # TODO: unset dirty flag
+        # unset dirty flag
+        obj.dirty = False
 
         # populate new primary keys
-        primary_keys = self.table.engine.last_inserted_ids()
+        primary_keys = table.engine.last_inserted_ids()
         index = 0
-        for pk in self.table.primary_keys:
+        for pk in table.primary_keys:
             newid = primary_keys[index]
             index += 1
             # TODO: do this via the ColumnProperty objects
@@ -231,15 +251,24 @@ class Mapper(object):
 
         self.put(obj)
 
-    def update(self, obj):
-        """inserts the object into its table, regardless of primary key being set.  this is a 
+    def update(self, obj, table = None):
+        """updates an object in one table, regardless of primary key being set.  this is a 
         lower-level operation than save."""
+
+        if table is None:
+            table = self.table
         params = {}
-        for col in self.table.columns:
-            params[col.key] = getattr(obj, col.key)
-        upd = self.table.update()
+        clause = sql.and_()
+        for col in table.columns:
+            if col.primary_key:
+                clause.clauses.append(col == getattr(obj, col.key))
+            else:
+                params[col.key] = getattr(obj, col.key)
+        upd = table.update(clause)
+        upd.echo = self.echo
         upd.execute(**params)
-        # TODO: unset dirty flag
+        # unset dirty flag
+        obj.dirty = False
 
     def delete(self, obj):
         """deletes the object's row from its table unconditionally. this is a lower-level
@@ -251,15 +280,11 @@ class Mapper(object):
         pass
 
     class TableFinder(sql.ClauseVisitor):
+        def __init__(self):
+            self.tables = []
         def visit_table(self, table):
-            if hasattr(self, 'table'):
-                raise "Mapper can only create object instances against a single-table identity - specify the 'table' argument to the Mapper constructor"
-            self.table = table
-            
-    def _find_table(self, selectable):
-        tf = Mapper.TableFinder()
-        selectable.accept_visitor(tf)
-        return tf.table
+            self.tables.append(table)
+
 
     def _compile(self, whereclause = None, **options):
         statement = sql.select([self.selectable], whereclause)
@@ -267,7 +292,7 @@ class Mapper(object):
             value.setup(key, self.selectable, statement, **options) 
         statement.use_labels = True
         return statement
-        
+
     def _select_whereclause(self, whereclause = None, **params):
         statement = self._compile(whereclause)
         return self._select_statement(statement, **params)
@@ -280,7 +305,6 @@ class Mapper(object):
     def _identity_key(self, row):
         return self.identitymap.get_key(row, self.class_, self.table, self.selectable)
 
-
     def _instance(self, row, localmap, result):
         """pulls an object instance from the given row and appends it to the given result list.
         if the instance already exists in the given identity map, its not added.  in either
@@ -293,7 +317,7 @@ class Mapper(object):
         exists = self.identitymap.has_key(identitykey)
         if not exists:
             instance = self.class_()
-            for column in self.selectable.primary_keys:
+            for column in self.table.primary_keys:
                 if row[column.label] is None:
                     return None
             self.identitymap[identitykey] = instance
@@ -308,7 +332,6 @@ class Mapper(object):
             imap = localmap[id(result)]
         except KeyError:
             imap = localmap.setdefault(id(result), IdentityMap())
-        
         isduplicate = imap.has_key(identitykey)
         if not isduplicate:
             imap[identitykey] = instance
@@ -325,7 +348,6 @@ class MapperOption:
     of it.  This is used to assist in the prototype pattern used by mapper.options()."""
     def process(self, mapper):
         raise NotImplementedError()
-    
     def hash_key(self):
         return repr(self)
 
@@ -386,6 +408,8 @@ class ColumnProperty(MapperProperty):
 
 
 class PropertyLoader(MapperProperty):
+    """describes an object property that holds a list of items that correspond to a related
+    database table."""
     def __init__(self, mapper, secondary, primaryjoin, secondaryjoin):
         self.mapper = mapper
         self.target = self.mapper.selectable
@@ -484,6 +508,7 @@ class LazyLoadInstance(object):
         return self.mapper.select(self.lazywhere, **self.params)
 
 class EagerLoader(PropertyLoader):
+    """loads related objects inline with a parent query."""
     def init(self, key, parent, root):
         PropertyLoader.init(self, key, parent, root)
         self.to_alias = util.Set()
@@ -504,22 +529,22 @@ class EagerLoader(PropertyLoader):
                 aliasizer = Aliasizer(target, "aliased_" + target.name + "_" + hex(random.randint(0, 65535))[2:])
                 statement.whereclause.accept_visitor(aliasizer)
                 statement.append_from(aliasizer.alias)
-        
+
         if hasattr(statement, '_outerjoin'):
             towrap = statement._outerjoin
         else:
             towrap = primarytable
-        
+
         if self.secondaryjoin is not None:
             statement._outerjoin = sql.outerjoin(sql.outerjoin(towrap, self.secondary, self.secondaryjoin), self.target, self.primaryjoin)
         else:
             statement._outerjoin = sql.outerjoin(towrap, self.target, self.primaryjoin)
-            
+
         statement.append_from(statement._outerjoin)
         statement.append_column(self.target)
         for key, value in self.mapper.props.iteritems():
             value.setup(key, self.mapper.selectable, statement)
-        
+
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         """receive a row.  tell our mapper to look for a new object instance in the row, and attach
         it to a list on the parent instance."""
@@ -561,15 +586,13 @@ class Aliasizer(sql.ClauseVisitor):
         if isinstance(binary.right, schema.Column) and binary.right.table == self.table:
             binary.right = self.alias.c[binary.right.name]
 
-    
 class LazyRow(MapperProperty):
+    """TODO: this will lazy-load additional properties of an object from a secondary table."""
     def __init__(self, table, whereclause, **options):
         self.table = table
         self.whereclause = whereclause
-
     def init(self, key, parent, root):
         self.keys.append(key)
-
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         pass
 
@@ -599,9 +622,9 @@ class IdentityMap(dict):
     def get_id_key(self, ident, class_, table, selectable):
         return (class_, table, tuple(ident))
     def get_instance_key(self, object, class_, table, selectable):
-        return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys]))
+        return (class_, table, tuple([getattr(object, column.key, None) for column in table.primary_keys]))
     def get_key(self, row, class_, table, selectable):
-        return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))
+        return (class_, table, tuple([row[column.label] for column in table.primary_keys]))
     def hash_key(self):
         return "IdentityMap(%s)" % id(self)
 
index c44d8dcfc5b8c6b4a86f2ae0cec6293daade4086..81dcb8e53bdcb08fee3226ed80801b67daa49dfc 100644 (file)
@@ -10,12 +10,13 @@ class User(object):
 objid: %d
 User ID: %s
 User Name: %s
+email address ?: %s
 Addresses: %s
 Orders: %s
 Open Orders %s
 Closed Orderss %s
 ------------------
-""" % tuple([id(self), self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
+""" % tuple([id(self), self.user_id, repr(self.user_name), repr(getattr(self, 'email_address', None))] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
 )
 
 class Address(object):
@@ -69,6 +70,12 @@ class MapperTest(AssertMixin):
         l = m.select(users.c.user_name.endswith('ed'))
         self.assert_result(l, User, {'user_id' : 8}, {'user_id' : 9})
 
+    def testmultitable(self):
+        usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
+        m = mapper(User, usersaddresses, table = users)
+        l = m.select()
+        print repr(l)
+
     def testeageroptions(self):
         """tests that a lazy relation can be upgraded to an eager relation via the options method"""
         m = mapper(User, users, properties = dict(
@@ -103,7 +110,7 @@ class LazyTest(AssertMixin):
             addresses = relation(Address, addresses, lazy = True)
         ), echo = True)
         l = m.select(users.c.user_id == 7)
-        self.assert_result(l, User, 
+        self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
             )
 
@@ -117,10 +124,10 @@ class LazyTest(AssertMixin):
         l = m.select()
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
             {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
-            {'item_id' : 5, 'keywords' : (Keyword, [])},
-            {'item_id' : 4, 'keywords' : (Keyword, [])}
+            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
+            {'item_id' : 4, 'keywords' : (Keyword, [])},
+            {'item_id' : 5, 'keywords' : (Keyword, [])}
         )
 
         l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
@@ -230,9 +237,61 @@ class SaveTest(PersistTest):
         u.user_name = 'inserttester'
         m = mapper(User, users, echo=True)
         m.insert(u)
+#        nu = m.get(u.user_id)
+        nu = m.select(users.c.user_id == u.user_id)[0]
+        self.assert_(u is nu)
+
+    def testsave(self):
+        # save two users
+        u = User()
+        u.user_name = 'savetester'
+        u2 = User()
+        u2.user_name = 'savetester2'
+        m = mapper(User, users, echo=True)
+        m.save(u)
+        m.save(u2)
+        
+        # assert the first one retreives the same from the identity map
         nu = m.get(u.user_id)
-    #    nu = m.select(users.c.user_id == u.user_id)[0]
         self.assert_(u is nu)
+        
+        # clear out the identity map, so next get forces a SELECT
+        m.identitymap.clear()
+
+        # check it again, identity should be different but ids the same
+        nu = m.get(u.user_id)
+        self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester')
+        
+        # change first users name and save
+        u.user_name = 'modifiedname'
+        m.save(u)
 
+        # select both
+        userlist = m.select(users.c.user_id.in_(u.user_id, u2.user_id))
+        # making a slight assumption here about the IN clause mechanics with regards to ordering
+        self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname')
+        self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2')
+
+    def testsavemultitable(self):
+        usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
+        m = mapper(User, usersaddresses, table = users)
+        u = User()
+        u.user_name = 'multitester'
+        u.email_address = 'multi@test.org'
+        m.save(u)
+        
+        usertable = engine.ResultProxy(users.select().execute()).fetchall()
+        print repr(usertable)
+        addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
+        print repr(addresstable)
+        
+        u.email_address = 'lala@hey.com'
+        u.user_name = 'imnew'
+        m.save(u)
+        usertable = engine.ResultProxy(users.select().execute()).fetchall()
+        print repr(usertable)
+        addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
+        print repr(addresstable)
+        
 if __name__ == "__main__":
     unittest.main()        
index af01d3191fbca84735d06a8cfb05c77b68f62900..c92ae70e9ea1f748fcb7eb66624a71bd3ed1775b 100644 (file)
@@ -3,7 +3,7 @@ import unittest, sys
 
 import sqlalchemy.databases.sqlite as sqllite
 
-db = sqllite.engine('querytest.db', echo = True)
+db = sqllite.engine(':memory:', {}, echo = True)
 
 from sqlalchemy.sql import *
 from sqlalchemy.schema import *