]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added explicit "session" argument to get(), select_whereclause in mapper, as well...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 19:01:14 +0000 (19:01 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 19:01:14 +0000 (19:01 +0000)
examples/adjacencytree/byroot_tree.py
examples/polymorph/polymorph2.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/util.py
test/mapper.py

index f564b1f9ef4212ba98dee5c3ddabb0cd167c6c0c..ece90e8d5145820475ce2f68fe72b73c4f413775 100644 (file)
@@ -45,6 +45,7 @@ class NodeList(util.OrderedDict):
     def __iter__(self):
         return iter(self.values())
 
+
 class TreeNode(object):
     """a hierarchical Tree class, which adds the concept of a "root node".  The root is 
     the topmost node in a tree, or in other words a node whose parent ID is NULL.  
index 8f93fcdfa4a3441bada85d2d2a0d0d885ff2c5e2..351a06eca1cf7097767951c3223b842560a7c2d4 100644 (file)
@@ -86,12 +86,12 @@ class PersonLoader(MapperExtension):
         else:
             return Person()
             
-    def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
+    def populate_instance(self, mapper, session, instance, row, identitykey, imap, isnew):
         if row[person_join.c.type] =='engineer':
-            Engineer.mapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper)
+            Engineer.mapper.populate_instance(session, instance, row, identitykey, imap, isnew, frommapper=mapper)
             return False
         elif row[person_join.c.type] =='manager':
-            Manager.mapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper)
+            Manager.mapper.populate_instance(session, instance, row, identitykey, imap, isnew, frommapper=mapper)
             return False
         else:
             return sqlalchemy.mapping.EXT_PASS
index 2db6e715a6fd7ef12dee28394c4edd35ed2dff79..e82aaeb12e7bedd21db4350f2eca79253e666819 100644 (file)
@@ -301,6 +301,9 @@ class Mapper(object):
     def instances(self, cursor, *mappers, **kwargs):
         limit = kwargs.get('limit', None)
         offset = kwargs.get('offset', None)
+        session = kwargs.get('session', None)
+        if session is None:
+            session = objectstore.get_session()
         populate_existing = kwargs.get('populate_existing', False)
         
         result = util.HistoryArraySet()
@@ -314,32 +317,34 @@ class Mapper(object):
             row = cursor.fetchone()
             if row is None:
                 break
-            self._instance(row, imap, result, populate_existing=populate_existing)
+            self._instance(session, row, imap, result, populate_existing=populate_existing)
             i = 0
             for m in mappers:
-                m._instance(row, imap, otherresults[i])
+                m._instance(session, row, imap, otherresults[i])
                 i+=1
                 
         # store new stuff in the identity map
         for value in imap.values():
-            objectstore.get_session().register_clean(value)
+            session.register_clean(value)
 
         if mappers:
             result = [result] + otherresults
         return result
             
-    def get(self, *ident):
+    def get(self, *ident, **kwargs):
         """returns an instance of the object based on the given identifier, or None
         if not found.  The *ident argument is a 
         list of primary key columns in the order of the table def's primary key columns."""
         key = objectstore.get_id_key(ident, self.class_, self.entity_name)
         #print "key: " + repr(key) + " ident: " + repr(ident)
-        return self._get(key, ident)
+        return self._get(key, ident, **kwargs)
         
-    def _get(self, key, ident=None, reload=False):
+    def _get(self, key, ident=None, reload=False, session=None):
         if not reload and not self.always_refresh:
             try:
-                return objectstore.get_session()._get(key)
+                if session is None:
+                    session = objectstore.get_session()
+                return session._get(key)
             except KeyError:
                 pass
             
@@ -352,7 +357,7 @@ class Mapper(object):
             i += 1
         try:
             statement = self._compile(self._get_clause)
-            return self._select_statement(statement, params=params, populate_existing=reload)[0]
+            return self._select_statement(statement, params=params, populate_existing=reload, session=session)[0]
         except IndexError:
             return None
 
@@ -552,9 +557,9 @@ class Mapper(object):
         else:
             return self.select_whereclause(whereclause=arg, **kwargs)
 
-    def select_whereclause(self, whereclause=None, params=None, **kwargs):
+    def select_whereclause(self, whereclause=None, params=None, session=None, **kwargs):
         statement = self._compile(whereclause, **kwargs)
-        return self._select_statement(statement, params=params)
+        return self._select_statement(statement, params=params, session=session)
 
     def count(self, whereclause=None, params=None, **kwargs):
         s = self.table.count(whereclause)
@@ -856,7 +861,7 @@ class Mapper(object):
     def _identity_key(self, row):
         return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table], self.entity_name)
 
-    def _instance(self, row, imap, result = None, populate_existing = False):
+    def _instance(self, session, row, imap, result = None, populate_existing = False):
         """pulls an object instance from the given row and appends it to the given result
         list. if the instance already exists in the given identity map, its not added.  in
         either case, executes all the property loaders on the instance to also process extra
@@ -865,18 +870,20 @@ class Mapper(object):
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
         
+        if session is None:
+            session = objectstore.get_session()
+            
         populate_existing = populate_existing or self.always_refresh
         identitykey = self._identity_key(row)
-        sess = objectstore.get_session()
-        if sess.has_key(identitykey):
-            instance = sess._get(identitykey)
+        if session.has_key(identitykey):
+            instance = session._get(identitykey)
 
             isnew = False
-            if populate_existing or sess.is_expired(instance, unexpire=True):
+            if populate_existing or session.is_expired(instance, unexpire=True):
                 if not imap.has_key(identitykey):
                     imap[identitykey] = instance
                 for prop in self.props.values():
-                    prop.execute(instance, row, identitykey, imap, True)
+                    prop.execute(session, instance, row, identitykey, imap, True)
             if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
                 if result is not None:
                     result.append_nohistory(instance)
@@ -893,7 +900,7 @@ class Mapper(object):
             # plugin point
             instance = self.extension.create_instance(self, row, imap, self.class_)
             if instance is EXT_PASS:
-                instance = self.class_(_mapper_nohistory=True, _sa_entity_name=self.entity_name)
+                instance = self.class_(_mapper_nohistory=True, _sa_entity_name=self.entity_name, _sa_session=session)
             imap[identitykey] = instance
             isnew = True
         else:
@@ -904,8 +911,8 @@ class Mapper(object):
         
         # call further mapper properties on the row, to pull further 
         # instances from the row and possibly populate this item.
-        if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew) is EXT_PASS:
-            self.populate_instance(instance, row, identitykey, imap, isnew)
+        if self.extension.populate_instance(self, session, instance, row, identitykey, imap, isnew) is EXT_PASS:
+            self.populate_instance(session, instance, row, identitykey, imap, isnew)
         if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
             if result is not None:
                 result.append_nohistory(instance)
@@ -923,17 +930,17 @@ class Mapper(object):
             newrow[c] = newrow[c.key]
         return newrow
         
-    def populate_instance(self, instance, row, identitykey, imap, isnew, frommapper=None):
+    def populate_instance(self, session, instance, row, identitykey, imap, isnew, frommapper=None):
         if frommapper is not None:
             row = frommapper.translate_row(self, row)
             
         for prop in self.props.values():
-            prop.execute(instance, row, identitykey, imap, isnew)
+            prop.execute(session, instance, row, identitykey, imap, isnew)
         
 class MapperProperty(object):
     """an element attached to a Mapper that describes and assists in the loading and saving 
     of an attribute on an object instance."""
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         """called when the mapper receives a row.  instance is the parent instance
         corresponding to the row. """
         raise NotImplementedError()
@@ -1054,7 +1061,7 @@ class MapperExtension(object):
             return EXT_PASS
         else:
             return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing)
-    def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
+    def populate_instance(self, mapper, session, instance, row, identitykey, imap, isnew):
         """called right before the mapper, after creating an instance from a row, passes the row
         to its MapperProperty objects which are responsible for populating the object's attributes.
         If this method returns True, it is assumed that the mapper should do the appending, else
@@ -1062,14 +1069,14 @@ class MapperExtension(object):
         
         Essentially, this method is used to have a different mapper populate the object:
         
-            def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
-                othermapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper)
+            def populate_instance(self, mapper, session, instance, row, identitykey, imap, isnew):
+                othermapper.populate_instance(session, instance, row, identitykey, imap, isnew, frommapper=mapper)
                 return True
         """
         if self.next is None:
             return EXT_PASS
         else:
-            return self.next.populate_instance(mapper, instance, row, identitykey, imap, isnew)
+            return self.next.populate_instance(mapper, session, instance, row, identitykey, imap, isnew)
     def before_insert(self, mapper, instance):
         """called before an object instance is INSERTed into its table.
         
index 14e4749099268d81f7c3c7c20deb0fa846371694..592e4dc0ac5afb5ba392c8a90a1301c5a6d07e85 100644 (file)
@@ -46,7 +46,7 @@ class ColumnProperty(MapperProperty):
         if parent._is_primary_mapper():
             #print "regiser col on class %s key %s" % (parent.class_.__name__, key)
             objectstore.uow().register_attribute(parent.class_, key, uselist = False)
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         if isnew:
             #print "POPULATING OBJ", instance.__class__.__name__, "COL", self.columns[0]._label, "WITH DATA", row[self.columns[0]], "ROW IS A", row.__class__.__name__, "COL ID", id(self.columns[0])
             instance.__dict__[self.key] = row[self.columns[0]]
@@ -95,7 +95,7 @@ class DeferredColumnProperty(ColumnProperty):
         return lazyload
     def setup(self, key, statement, **options):
         pass
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         if isnew:
             if not self.is_primary():
                 objectstore.global_attributes.create_history(instance, self.key, False, callable_=self.setup_loader(instance))
@@ -533,7 +533,7 @@ class PropertyLoader(MapperProperty):
                             self._synchronize(obj, child, None, True)
                             uowcommit.register_object(child, isdelete=self.private)
 
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         if self.is_primary():
             return
         #print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key
@@ -595,6 +595,7 @@ class LazyLoader(PropertyLoader):
         def lazyload():
             params = {}
             allparams = True
+            session = objectstore.get_session(instance)
             #print "setting up loader, lazywhere", str(self.lazywhere)
             for col, bind in self.lazybinds.iteritems():
                 params[bind.key] = self.parent._getattrbycolumn(instance, col)
@@ -608,14 +609,14 @@ class LazyLoader(PropertyLoader):
                     ident = []
                     for primary_key in self.mapper.pks_by_table[self.mapper.table]:
                         ident.append(params[primary_key._label])
-                    return self.mapper.get(*ident)
+                    return self.mapper.get(session=session, *ident)
                 elif self.order_by is not False:
                     order_by = self.order_by
                 elif self.secondary is not None and self.secondary.default_order_by() is not None:
                     order_by = self.secondary.default_order_by()
                 else:
                     order_by = False
-                result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params)
+                result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params, session=session)
             else:
                 result = []
             if self.uselist:
@@ -627,7 +628,7 @@ class LazyLoader(PropertyLoader):
                     return None
         return lazyload
         
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         if isnew:
             # new object instance being loaded from a result row
             if not self.is_primary():
@@ -780,7 +781,7 @@ class EagerLoader(PropertyLoader):
             value.setup(key, statement, eagertable=self.eagertarget)
             
         
-    def execute(self, instance, row, identitykey, imap, isnew):
+    def execute(self, session, instance, row, identitykey, imap, isnew):
         """receive a row.  tell our mapper to look for a new object instance in the row, and attach
         it to a list on the parent instance."""
         
@@ -791,11 +792,11 @@ class EagerLoader(PropertyLoader):
             
         if not self.uselist:
             if isnew:
-                h.setattr_clean(self._instance(row, imap))
+                h.setattr_clean(self._instance(session, row, imap))
             else:
                 # call _instance on the row, even though the object has been created,
                 # so that we further descend into properties
-                self._instance(row, imap)
+                self._instance(session, row, imap)
                 
             return
         elif isnew:
@@ -803,7 +804,7 @@ class EagerLoader(PropertyLoader):
         else:
             result_list = getattr(instance, self.key)
     
-        self._instance(row, imap, result_list)
+        self._instance(session, row, imap, result_list)
 
     def _create_decorator_row(self):
         class DecoratorDict(object):
@@ -823,7 +824,7 @@ class EagerLoader(PropertyLoader):
             map[parent.name] = c
         return DecoratorDict
         
-    def _instance(self, row, imap, result_list=None):
+    def _instance(self, session, row, imap, result_list=None):
         """gets an instance from a row, via this EagerLoader's mapper."""
         # since the EagerLoader makes an Alias of its mapper's table,
         # we translate the actual result columns back to what they 
@@ -833,7 +834,7 @@ class EagerLoader(PropertyLoader):
         # (which is what mappers use) as well as its "label" (which might be what
         # user-defined code is using)
         row = self._row_decorator(row)
-        return self.mapper._instance(row, imap, result_list)
+        return self.mapper._instance(session, row, imap, result_list)
 
 class GenericOption(MapperOption):
     """a mapper option that can handle dotted property names,
index 44209b10494c86b611653b80a2696d5cbb9c8d92..9369a727036881fb0afeb2c8bef347a0e4462837 100644 (file)
@@ -294,6 +294,8 @@ class HistoryArraySet(UserList.UserList):
         else:
             self.data = []
         self.readonly=readonly
+#    def __iter__(self):
+#        return iter([k for k in self.records if self.records[k] is not False])
     def __getattr__(self, attr):
         """proxies unknown HistoryArraySet methods and attributes to the underlying
         data array.  this allows custom list classes to be used."""
index 5c1b268eb019c6bc7c200c8811a73cf79acb4154..41c29cfdfcb0d9f4661f26faede7ae7106682f55 100644 (file)
@@ -124,6 +124,13 @@ class MapperTest(MapperSuperTest):
             objectstore.refresh(u)
         self.assert_sql_count(db, go, 1)
 
+    def testsessionpropigation(self):
+        sess = objectstore.Session()
+        m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=True)})
+        u = m.get(7, session=sess)
+        assert objectstore.get_session(u) is sess
+        assert objectstore.get_session(u.addresses[0]) is sess
+        
     def testexpire(self):
         m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
         u = m.get(7)