]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
remove a little cruft
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Nov 2007 03:34:06 +0000 (03:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Nov 2007 03:34:06 +0000 (03:34 +0000)
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py

index 52523d787677f7e676628688af37a18fbc9b0ef8..5673be44c117233a8d2338f345a515a3dde396ed 100644 (file)
@@ -401,6 +401,11 @@ class Mapper(object):
         if not self.tables:
             raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
 
+        # TODO: move the "figure pks" step down into compile_properties; after 
+        # all columns have been mapped, assemble PK columns and their
+        # proxied parents into the pks_by_table collection, then get rid 
+        # of the _has_pks method
+        
         # determine primary key columns
         self.pks_by_table = {}
 
@@ -1004,6 +1009,7 @@ class Mapper(object):
                     continue
                 pks = mapper.pks_by_table[table]
                 instance_key = mapper.identity_key_from_instance(obj)
+
                 if self.__should_log_debug:
                     self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key)))
 
@@ -1042,7 +1048,7 @@ class Mapper(object):
                             params[col._label] = mapper.get_attr_by_column(obj, col)
                             params[col.key] = params[col._label] + 1
                             for prop in mapper._columntoproperty.values():
-                                history = prop.get_history(obj, passive=True)
+                                history = attributes.get_history(obj, prop.key, passive=True)
                                 if history and history.added_items():
                                     hasdata = True
                         elif col in pks:
@@ -1055,7 +1061,7 @@ class Mapper(object):
                             prop = mapper._getpropbycolumn(col, False)
                             if prop is None:
                                 continue
-                            history = prop.get_history(obj, passive=True)
+                            history = attributes.get_history(obj, prop.key, passive=True)
                             if history:
                                 a = history.added_items()
                                 if a:
@@ -1065,8 +1071,6 @@ class Mapper(object):
                                         params[col.key] = prop.get_col_value(col, a[0])
                                     hasdata = True
                     if hasdata:
-                        # if none of the attributes changed, dont even
-                        # add the row to be updated.
                         update.append((obj, params, mapper, connection, value_params))
 
             if update:
@@ -1079,8 +1083,9 @@ class Mapper(object):
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
+                pks = mapper.pks_by_table[table]
                 def comparator(a, b):
-                    for col in mapper.pks_by_table[table]:
+                    for col in pks:
                         x = cmp(a[1][col._label],b[1][col._label])
                         if x != 0:
                             return x
@@ -1119,13 +1124,11 @@ class Mapper(object):
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
                     # per table
-                    def sync(mapper):
-                        inherit = mapper.inherits
-                        if inherit is not None:
-                            sync(inherit)
-                        if mapper._synchronizer is not None:
-                            mapper._synchronizer.execute(obj, obj)
-                    sync(mapper)
+                    mappers = list(mapper.iterate_to_root())
+                    mappers.reverse()
+                    for m in mappers:
+                        if m._synchronizer is not None:
+                            m._synchronizer.execute(obj, obj)
 
                     # testlib.pragma exempt:__hash__
                     inserted_objects.add((id(obj), obj, connection))
index 00fa8f9d3d0ce3f75ab10db49aeaccd76b47a752..2c50ec92fca5b3b292fbee7a52363ade6a178072 100644 (file)
@@ -58,9 +58,6 @@ class ColumnProperty(StrategizedProperty):
     def setattr(self, object, value, column):
         setattr(object, self.key, value)
 
-    def get_history(self, obj, passive=False):
-        return attributes.get_history(obj, self.key, passive=passive)
-
     def merge(self, session, source, dest, dont_load, _recursive):
         setattr(dest, self.key, getattr(source, self.key, None))
 
index 6b6d678359c25829317904ddd17dffa903298089..7097273f59c1326dabd3bfe0536e4618456d04bf 100644 (file)
@@ -571,6 +571,8 @@ class Session(object):
         self.uow = unitofwork.UnitOfWork(self)
         self.identity_map = self.uow.identity_map
 
+    # TODO: need much more test coverage for bind_mapper() and similar !
+    
     def bind_mapper(self, mapper, bind, entity_name=None):
         """Bind the given `mapper` or `class` to the given ``Engine`` or ``Connection``.