]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2005 06:08:01 +0000 (06:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Sep 2005 06:08:01 +0000 (06:08 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py

index 7c593b600d421388dfd40f4862c7e87b1f6a36d4..03d5190589642d9f0427689d658e1338bf105d06 100644 (file)
@@ -257,7 +257,7 @@ class Mapper(object):
     def _compile(self, whereclause = None, **options):
         statement = sql.select([self.selectable], whereclause)
         for key, value in self.props.iteritems():
-            value.setup(key, self.selectable, statement, **options) 
+            value.setup(key, statement, **options) 
         statement.use_labels = True
         return statement
 
@@ -329,7 +329,7 @@ class MapperProperty:
         a process."""
         raise NotImplementedError()
 
-    def setup(self, key, primarytable, statement, **options):
+    def setup(self, key, statement, **options):
         """called when a statement is being constructed.  """
         return self
 
@@ -538,16 +538,17 @@ class LazyLoader(PropertyLoader):
             if not issubclass(parent.class_, object):
                 raise "LazyLoader can only be used with new-style classes"
             setattr(parent.class_, key, SmartProperty(key).property())
-
-    def setup(self, key, primarytable, statement, **options):
         if self.secondaryjoin is not None:
             self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin)
         else:
             self.lazywhere = self.primaryjoin
+
+        # we dont want to screw with the primaryjoin and secondary join of the PropertyLoader,
+        # so create a copy
         self.lazywhere = self.lazywhere.copy_container()
-        li = LazyIzer(primarytable)
+        self.binds = {}
+        li = BinaryVisitor(lambda b: self._create_lazy_clause(b, self.binds))
         self.lazywhere.accept_visitor(li)
-        self.binds = li.binds
 
     def execute(self, instance, row, identitykey, isduplicate):
         if not isduplicate:
@@ -556,6 +557,15 @@ class LazyLoader(PropertyLoader):
             # when u deserialize tho
             objectstore.uow().attribute_set_callable(instance, self.key, LazyLoadInstance(self, row))
 
+    def _create_lazy_clause(self, binary, binds):
+        if isinstance(binary.left, schema.Column) and binary.left.table == self.parent.selectable:
+            binary.left = binds.setdefault(self.parent.selectable.name + "_" + binary.left.name,
+                    sql.BindParamClause(self.parent.selectable.name + "_" + binary.left.name, None, shortname = binary.left.name))
+
+        if isinstance(binary.right, schema.Column) and binary.right.table == self.parent.selectable:
+            binary.right = binds.setdefault(self.parent.selectable.name + "_" + binary.right.name,
+                    sql.BindParamClause(self.parent.selectable.name + "_" + binary.right.name, None, shortname = binary.right.name))
+
 class LazyLoadInstance(object):
     """attached to a specific object instance to load related rows."""
     def __init__(self, lazyloader, row):
@@ -589,12 +599,13 @@ class EagerLoader(PropertyLoader):
             [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()]
         del self.to_alias[parent.selectable]
 
-    def setup(self, key, primarytable, statement, **options):
+    def setup(self, key, statement, **options):
         """add a left outer join to the statement thats being constructed"""
 
         if statement.whereclause is not None:
             # "aliasize" the tables referenced in the user-defined whereclause to not 
             # collide with the tables used by the eager load
+            # note that we arent affecting the mapper's selectable, nor our own primary or secondary joins
             aliasizer = Aliasizer(*self.to_alias)
             statement.whereclause.accept_visitor(aliasizer)
             for alias in aliasizer.aliases.values():
@@ -603,7 +614,7 @@ class EagerLoader(PropertyLoader):
         if hasattr(statement, '_outerjoin'):
             towrap = statement._outerjoin
         else:
-            towrap = primarytable
+            towrap = self.parent.selectable
 
         if self.secondaryjoin is not None:
             statement._outerjoin = sql.outerjoin(sql.outerjoin(towrap, self.secondary, self.secondaryjoin), self.target, self.primaryjoin)
@@ -613,7 +624,7 @@ class EagerLoader(PropertyLoader):
         statement.append_from(statement._outerjoin)
         statement.append_column(self.target)
         for key, value in self.mapper.props.iteritems():
-            value.setup(key, self.mapper.selectable, statement)
+            value.setup(key, statement)
 
     def execute(self, instance, row, identitykey, isduplicate):
         """receive a row.  tell our mapper to look for a new object instance in the row, and attach
@@ -685,27 +696,6 @@ class BinaryVisitor(sql.ClauseVisitor):
     def visit_binary(self, binary):
         self.func(binary)
         
-class LazyIzer(sql.ClauseVisitor):
-    """converts an expression which refers to a table column into an
-    expression refers to a Bind Param, i.e. a specific value.  
-    e.g. the clause 'WHERE tablea.foo=tableb.foo' becomes 'WHERE tablea.foo=:foo'.  
-    this is used to turn a join expression into one useable by a lazy load
-    for a specific parent row."""
-
-    def __init__(self, table):
-        self.table = table
-        self.binds = {}
-
-    def visit_binary(self, binary):
-        if isinstance(binary.left, schema.Column) and binary.left.table == self.table:
-            binary.left = self.binds.setdefault(self.table.name + "_" + binary.left.name,
-                    sql.BindParamClause(self.table.name + "_" + binary.left.name, None, shortname = binary.left.name))
-
-        if isinstance(binary.right, schema.Column) and binary.right.table == self.table:
-            binary.right = self.binds.setdefault(self.table.name + "_" + binary.right.name,
-                    sql.BindParamClause(self.table.name + "_" + binary.right.name, None, shortname = binary.right.name))
-
-
 class SmartProperty(object):
     def __init__(self, key):
         self.key = key
index d27f5588eb65e440691076c234ca1282727cad3c..3057f7ef0d2fff8cea520bcc44f3fa72aab1a99c 100644 (file)
@@ -123,13 +123,11 @@ class UOWListElement(util.HistoryArraySet):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
             uow().modified_lists.append(self.listpointer)
-            #uow().register_dirty(self.obj())
         return res
     def _delrecord(self, item):
         res = util.HistoryArraySet._delrecord(self, item)
         if res:
             uow().modified_lists.append(self.listpointer)
-            #uow().register_dirty(self.obj())
         return res
     
 class UnitOfWork(object):
@@ -220,27 +218,20 @@ class UnitOfWork(object):
         self.dependencies = {}
         self.tasks = {}
         
-        
         for obj in [n for n in self.new] + [d for d in self.dirty]:
             mapper = sqlalchemy.mapper.object_mapper(obj)
-            try:
-                task = self.tasks[mapper]
-            except KeyError:
-                task = self.tasks.setdefault(mapper, UOWTask(mapper))
+            task = self.get_task_by_mapper(mapper)
             task.objects.append(obj)
 
         for item in self.modified_lists:
             item = item.list
             obj = item.obj()
             mapper = sqlalchemy.mapper.object_mapper(obj)
-            try:
-                task = self.tasks[mapper]
-            except KeyError:
-                task = self.tasks.setdefault(mapper, UOWTask(mapper))
+            task = self.get_task_by_mapper(mapper)
             task.lists.append(obj)
             
         for task in self.tasks.values():
-            task.mapper.register_dependencies(task.objects + task.lists, self)
+            task.mapper.register_dependencies(util.HashSet(task.objects + task.lists), self)
             
         mapperlist = self.tasks.values()
         def compare(a, b):
@@ -252,10 +243,7 @@ class UnitOfWork(object):
                 return 0
         mapperlist.sort(compare)
         
-        # TODO: figure some way to process dependencies without saving a lead item,
-        # for the case when a list changes within a many-to-many
-        # also break save_obj into a list of tasks that are more SQL-specific
-        # generally, make this whole thing more straightforward and generic-'task' oriented
+        # TODO: break save_obj into a list of tasks that are more SQL-specific
         for task in mapperlist:
             obj_list = task.objects
             for obj in obj_list:
@@ -270,19 +258,24 @@ class UnitOfWork(object):
             item = item.list
             item.clear_history()
         self.modified_lists.clear()
-        
+
+        self.tasks.clear()
+        self.dependencies.clear()
         # TODO: deleted stuff
 
+    # TODO: better interface for tasks with no object save, or multiple dependencies
     def register_dependency(self, mapper, dependency, processor, stuff_to_process):
         self.dependencies[(mapper, dependency)] = True
-        try:
-            task = self.tasks[mapper]
-        except KeyError:
-            task = self.tasks.setdefault(mapper, UOWTask(mapper))
+        task = self.get_task_by_mapper(mapper)
         if processor is not None:
             task.dependencies.append((processor, stuff_to_process))
         
-
+    def get_task_by_mapper(self, mapper):
+        try:
+            return self.tasks[mapper]
+        except KeyError:
+            return self.tasks.setdefault(mapper, UOWTask(mapper))
+    
 class UOWTask(object):
     def __init__(self, mapper):
         self.mapper = mapper