]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Aug 2005 00:07:30 +0000 (00:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Aug 2005 00:07:30 +0000 (00:07 +0000)
lib/sqlalchemy/mapper.py
test/mapper.py

index bef7474f24f9bc1bc59e32476edec3c61150ba5d..62a1bb8567b0ba96e35f67aa7debdcee1412f23c 100644 (file)
@@ -96,7 +96,12 @@ class Mapper(object):
 
         self.props = {}
         for column in self.selectable.columns:
-            self.props[column.key] = ColumnProperty(column)
+            prop = self.props.get(column.key, None)
+            if prop is None:
+                prop = ColumnProperty(column)
+                self.props[column.key] = prop
+            else:
+                prop.columns.append(column)
         self.properties = properties
         if properties is not None:
             for key, value in properties.iteritems():
@@ -209,18 +214,45 @@ class Mapper(object):
         """
 
         if getattr(obj, 'dirty', True):
-            f = def():
+            def foo():
+                props = {}
+                for prop in self.props.values():
+                    if not isinstance(prop, ColumnProperty):
+                        continue
+                    for col in prop.columns:
+                        props[col] = prop
                 for table in self.tables:
-                    for col in table.columns:
-                        if getattr(obj, col.key, None) is None:
-                            self.insert(obj, table)
+                    params = {}
+                    for primary_key in table.primary_keys:
+                        if props[primary_key].getattr(obj) is None:
+                            statement = table.insert()
+                            for col in table.columns:
+                                params[col.key] = props[col].getattr(obj)
                             break
                     else:
-                        self.update(obj, table)
-
+                        clause = sql.and_()
+                        for col in table.columns:
+                            if col.primary_key:
+                                clause.clauses.append(col == props[col].getattr(obj))
+                            else:
+                                params[col.key] = props[col].getattr(obj)
+                        statement = table.update(clause)
+                    statement.echo = self.echo
+                    statement.execute(**params)
+                    if isinstance(statement, sql.Insert):
+                        primary_keys = table.engine.last_inserted_ids()
+                        index = 0
+                        for col in table.primary_keys:
+                            newid = primary_keys[index]
+                            index += 1
+                            props[col].setattr(obj, newid)
+                        self.put(obj)
+                # unset dirty flag
+                obj.dirty = False
                 for prop in self.props.values():
-                    prop.save(obj, traverse, refetch)
-            self.transaction(f)
+                    if not isinstance(prop, ColumnProperty):
+                        prop.save(obj, traverse, refetch)
+            self.transaction(foo)
         else:
             for prop in self.props.values():
                 prop.save(obj, traverse, refetch)
@@ -232,52 +264,6 @@ class Mapper(object):
         """removes the object.  traverse indicates attached objects should be removed as well."""
         pass
 
-    def insert(self, obj, table = None):
-        """inserts an object into one table, regardless of primary key being set.  this is a 
-        lower-level operation than save."""
-
-        if table is None:
-            table = self.table
-
-        params = {}
-        for col in table.columns:
-            params[col.key] = getattr(obj, col.key, None)
-        ins = table.insert()
-        ins.echo = self.echo
-        ins.execute(**params)
-
-        # unset dirty flag
-        obj.dirty = False
-
-        # populate new primary keys
-        primary_keys = table.engine.last_inserted_ids()
-        index = 0
-        for pk in table.primary_keys:
-            newid = primary_keys[index]
-            index += 1
-            # TODO: do this via the ColumnProperty objects
-            setattr(obj, pk.key, newid)
-
-        self.put(obj)
-
-    def update(self, obj, table = None):
-        """updates an object in one table, regardless of primary key being set.  this is a 
-        lower-level operation than save."""
-
-        if table is None:
-            table = self.table
-        params = {}
-        clause = sql.and_()
-        for col in table.columns:
-            if col.primary_key:
-                clause.clauses.append(col == getattr(obj, col.key))
-            else:
-                params[col.key] = getattr(obj, col.key)
-        upd = table.update(clause)
-        upd.echo = self.echo
-        upd.execute(**params)
-        # unset dirty flag
-        obj.dirty = False
 
     def delete(self, obj):
         """deletes the object's row from its table unconditionally. this is a lower-level
@@ -392,11 +378,15 @@ class MapperProperty:
 
 class ColumnProperty(MapperProperty):
     """describes an object attribute that corresponds to a table column."""
-    def __init__(self, column):
-        self.column = column
+    def __init__(self, *columns):
+        self.columns = list(columns)
 
+    def getattr(self, object):
+        return getattr(object, self.key, None)
+    def setattr(self, object, value):
+        setattr(object, self.key, value)
     def hash_key(self):
-        return "ColumnProperty(%s)" % hash_key(self.column)
+        return "ColumnProperty(%s)" % repr([hash_key(c) for c in self.columns])
 
     def init(self, key, parent, root):
         self.key = key
@@ -450,19 +440,20 @@ class PropertyLoader(MapperProperty):
         # if a mapping table exists, determine the two foreign key columns 
         # in the mapping table, set the two values, and insert that row, for
         # each row in the list
-        for child in getattr(obj, self.key):
-            setter = ForeignKeySetter(obj, child)
-            self.primaryjoin.accept_visitor(setter)
-            self.mapper.save(child)
+  #      for child in getattr(obj, self.key):
+  #          setter = ForeignKeySetter(obj, child)
+  #          self.primaryjoin.accept_visitor(setter)
+  #          self.mapper.save(child)
+        pass
 
     def delete(self):
         self.mapper.delete()
 
-class ForeignKeySetter(ClauseVisitor):
-    def visit_binary(self, binary):
-        if binary.operator == '==':
-            if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
-                setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname))
+#class ForeignKeySetter(ClauseVisitor):
#   def visit_binary(self, binary):
+  #      if binary.operator == '==':
+  #          if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
+  #              setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname))
 
 class LazyLoader(PropertyLoader):
 
index d1f29e43893e48c053856358de25060176591563..e8ed519f6fc5fe25cd22f750a13b4c70176da1d9 100644 (file)
@@ -232,14 +232,6 @@ class EagerTest(PersistTest):
 
 class SaveTest(PersistTest):
         
-    def testinsert(self):
-        u = User()
-        u.user_name = 'inserttester'
-        m = mapper(User, users, echo=True)
-        m.insert(u)
-#        nu = m.get(u.user_id)
-        nu = m.select(users.c.user_id == u.user_id)[0]
-        self.assert_(u is nu)
 
     def testsave(self):
         # save two users
@@ -274,24 +266,24 @@ class SaveTest(PersistTest):
 
     def testsavemultitable(self):
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
-        m = mapper(User, usersaddresses, table = users)
+        m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id)))
         u = User()
         u.user_name = 'multitester'
-        u.email_address = 'multi@test.org'
+        u.email = 'multi@test.org'
         m.save(u)
 
-        usertable = engine.ResultProxy(users.select().execute()).fetchall()
-        print repr(usertable)
-        addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
-        print repr(addresstable)
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall()
+        self.assert_(usertable[0].row == (10, 'multitester'))
+        addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall()
+        self.assert_(addresstable[0].row == (4, 10, 'multi@test.org'))
 
-        u.email_address = 'lala@hey.com'
+        u.email = 'lala@hey.com'
         u.user_name = 'imnew'
         m.save(u)
-        usertable = engine.ResultProxy(users.select().execute()).fetchall()
-        print repr(usertable)
-        addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
-        print repr(addresstable)
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall()
+        self.assert_(usertable[0].row == (10, 'imnew'))
+        addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall()
+        self.assert_(addresstable[0].row == (4, 10, 'lala@hey.com'))
 
 if __name__ == "__main__":
     unittest.main()