]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- base_mapper() becomes a plain attribute
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2007 03:19:46 +0000 (03:19 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2007 03:19:46 +0000 (03:19 +0000)
- session.execute() and scalar() can search for a Table with which to bind
from using the given ClauseElement
- session automatically extrapolates tables from mappers with binds,
also uses base_mapper so that inheritance hierarchies bind automatically
- moved ClauseVisitor traversal back to inlined non-recursive

lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/sql.py
test/orm/session.py
test/testlib/testing.py

index c04771b232003f177df91415b0eed521eca5842c..88c689a87cad52b4cdb66a3245f3c9f22b4268a9 100644 (file)
@@ -420,6 +420,7 @@ class MapperStub(object):
 
     def __init__(self, parent, mapper, key):
         self.mapper = mapper
+        self.base_mapper = self
         self.class_ = mapper.class_
         self._inheriting_mappers = []
 
@@ -438,5 +439,3 @@ class MapperStub(object):
     def primary_mapper(self):
         return self
 
-    def base_mapper(self):
-        return self
index abaeff49c5a906b9bfe8e8070e75838eda5e7137..74b184a7c254023ae2691b18daa3122d43aa2af8 100644 (file)
@@ -462,7 +462,7 @@ class LoaderStack(object):
         self.__stack.append(key)
         
     def push_mapper(self, mapper):
-        self.__stack.append(mapper.base_mapper())
+        self.__stack.append(mapper.base_mapper)
         
     def pop(self):
         self.__stack.pop()
index 85a7f426c012b058650912b4df59255511c2f2f5..9e3cb3aaf246a10d736c2c2776c86e74518fc557 100644 (file)
@@ -368,7 +368,11 @@ class Mapper(object):
             self.polymorphic_map = self.inherits.polymorphic_map
             self.batch = self.inherits.batch
             self.inherits._inheriting_mappers.add(self)
+            self.base_mapper = self.inherits.base_mapper
+            self._all_tables = self.inherits._all_tables
         else:
+            self._all_tables = util.Set()
+            self.base_mapper = self
             self._synchronizer = None
             self.mapped_table = self.local_table
             if self.polymorphic_identity is not None:
@@ -424,6 +428,7 @@ class Mapper(object):
         # go through all of our represented tables
         # and assemble primary key columns
         for t in self.tables + [self.mapped_table]:
+            self._all_tables.add(t)
             try:
                 l = self.pks_by_table[t]
             except KeyError:
@@ -534,7 +539,7 @@ class Mapper(object):
                     result[binary.right] = util.Set([binary.left])
         vis = mapperutil.BinaryVisitor(visit_binary)
 
-        for mapper in self.base_mapper().polymorphic_iterator():
+        for mapper in self.base_mapper.polymorphic_iterator():
             if mapper.inherit_condition is not None:
                 vis.traverse(mapper.inherit_condition)
 
@@ -716,19 +721,10 @@ class Mapper(object):
         if self.entity_name is None:
             self.class_.c = self.c
 
-    def base_mapper(self):
-        """Return the ultimate base mapper in an inheritance chain."""
-
-        # TODO: calculate this at mapper setup time
-        if self.inherits is not None:
-            return self.inherits.base_mapper()
-        else:
-            return self
-
     def common_parent(self, other):
         """Return true if the given mapper shares a common inherited parent as this mapper."""
 
-        return self.base_mapper() is other.base_mapper()
+        return self.base_mapper is other.base_mapper
 
     def isa(self, other):
         """Return True if the given mapper inherits from this mapper."""
@@ -752,7 +748,7 @@ class Mapper(object):
         all their inheriting mappers as well.
 
         To iterate through an entire hierarchy, use
-        ``mapper.base_mapper().polymorphic_iterator()``."""
+        ``mapper.base_mapper.polymorphic_iterator()``."""
 
         yield self
         for mapper in self._inheriting_mappers:
@@ -1033,7 +1029,7 @@ class Mapper(object):
         updated_objects = util.Set()
 
         table_to_mapper = {}
-        for mapper in self.base_mapper().polymorphic_iterator():
+        for mapper in self.base_mapper.polymorphic_iterator():
             for t in mapper.tables:
                 table_to_mapper.setdefault(t, mapper)
 
@@ -1247,7 +1243,7 @@ class Mapper(object):
         
         deleted_objects = util.Set()
         table_to_mapper = {}
-        for mapper in self.base_mapper().polymorphic_iterator():
+        for mapper in self.base_mapper.polymorphic_iterator():
             for t in mapper.tables:
                 table_to_mapper.setdefault(t, mapper)
 
index 1dfd1b665c283c6d45f2677b8104f8bc85256fbb..f8589dbe42445adc52cec2f5772da8e714d2e12d 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import util, exceptions, sql, engine
 from sqlalchemy.orm import unitofwork, query, util as mapperutil
 from sqlalchemy.orm.mapper import object_mapper as _object_mapper
 from sqlalchemy.orm.mapper import class_mapper as _class_mapper
+from sqlalchemy.orm.mapper import Mapper
 
 
 __all__ = ['Session', 'SessionTransaction', 'SessionExtension']
@@ -146,11 +147,8 @@ class SessionTransaction(object):
         self.autoflush = autoflush
         self.nested = nested
 
-    def connection(self, mapper_or_class, entity_name=None, **kwargs):
-        if isinstance(mapper_or_class, type):
-            mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
-        engine = self.session.get_bind(mapper_or_class, **kwargs)
-        return self.get_or_add(engine)
+    def connection(self, bindkey, **kwargs):
+        return self.session.connection(bindkey, **kwargs)
 
     def _begin(self, **kwargs):
         return SessionTransaction(self.session, self, **kwargs)
@@ -406,10 +404,13 @@ class Session(object):
         self._mapper_flush_opts = {}
         
         if binds is not None:
-            for mapperortable, value in binds:
+            for mapperortable, value in binds.iteritems():
                 if isinstance(mapperortable, type):
-                    mapperortable = _class_mapper(mapperortable)
+                    mapperortable = _class_mapper(mapperortable).base_mapper
                 self.__binds[mapperortable] = value
+                if isinstance(mapperortable, Mapper):
+                    for t in mapperortable._all_tables:
+                        self.__binds[t] = value
                 
         if self.transactional:
             self.begin()
@@ -504,11 +505,14 @@ class Session(object):
         to multiple engines or connections, or is not bound to any connectable.
         """
 
+        return self.__connection(self.get_bind(mapper))
+
+    def __connection(self, engine, **kwargs):
         if self.transaction is not None:
-            return self.transaction.connection(mapper)
+            return self.transaction.get_or_add(engine)
         else:
-            return self.get_bind(mapper).contextual_connect(**kwargs)
-
+            return engine.contextual_connect(**kwargs)
+        
     def execute(self, clause, params=None, mapper=None, **kwargs):
         """Using the given mapper to identify the appropriate ``Engine``
         or ``Connection`` to be used for statement execution, execute the
@@ -520,12 +524,17 @@ class Session(object):
         then the ``ResultProxy`` 's ``close()`` method will release the
         resources of the underlying ``Connection``.
         """
-        return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs)
+
+        engine = self.get_bind(mapper, clause=clause)
+        
+        return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs)
 
     def scalar(self, clause, params=None, mapper=None, **kwargs):
         """Like execute() but return a scalar result."""
 
-        return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs)
+        engine = self.get_bind(mapper, clause=clause)
+        
+        return self.__connection(engine, close_with_result=True).scalar(clause, params or {}, **kwargs)
 
     def close(self):
         """Close this Session.  
@@ -575,7 +584,9 @@ class Session(object):
         if isinstance(mapper, type):
             mapper = _class_mapper(mapper, entity_name=entity_name)
 
-        self.__binds[mapper] = bind
+        self.__binds[mapper.base_mapper] = bind
+        for t in mapper._all_tables:
+            self.__binds[t] = bind
 
     def bind_table(self, table, bind):
         """Bind the given `table` to the given ``Engine`` or ``Connection``.
@@ -586,45 +597,32 @@ class Session(object):
 
         self.__binds[table] = bind
 
-    def get_bind(self, mapper):
-        """Return the ``Engine`` or ``Connection`` which is used to execute
-        statements on behalf of the given `mapper`.
-
-        Calling ``connect()`` on the return result will always result
-        in a ``Connection`` object.  This method disregards any
-        ``SessionTransaction`` that may be in progress.
+    def get_bind(self, mapper, clause=None):
 
-        The order of searching is as follows:
-
-        1. if an ``Engine`` or ``Connection`` was bound to this ``Mapper``
-           specifically within this ``Session``, return that ``Engine`` or
-           ``Connection``.
-
-        2. if an ``Engine`` or ``Connection`` was bound to this `mapper` 's
-           underlying ``Table`` within this ``Session`` (i.e. not to the ``Table``
-           directly), return that ``Engine`` or ``Connection``.
-
-        3. if an ``Engine`` or ``Connection`` was bound to this ``Session``,
-           return that ``Engine`` or ``Connection``.
-
-        4. finally, return the ``Engine`` which was bound directly to the
-           ``Table`` 's ``MetaData`` object.
-
-        If no ``Engine`` is bound to the ``Table``, an exception is raised.
-        """
-
-        if mapper is None:
+        if mapper is None and clause is None:
             if self.bind is not None:
                 return self.bind
             else:
                 raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
-        elif self.__binds.has_key(mapper):
-            return self.__binds[mapper]
-        elif self.__binds.has_key(mapper.compile().mapped_table):
-            return self.__binds[mapper.mapped_table]
-        elif self.bind is not None:
+                
+        elif len(self.__binds):
+            if mapper is not None:
+                if isinstance(mapper, type):
+                    mapper = _class_mapper(mapper)
+                if self.__binds.has_key(mapper.base_mapper):
+                    return self.__binds[mapper.base_mapper]
+                elif self.__binds.has_key(mapper.compile().mapped_table):
+                    return self.__binds[mapper.mapped_table]
+            if clause is not None:
+                for t in clause._table_iterator():
+                    if t in self.__binds:
+                        return self.__binds[t]
+                        
+        if self.bind is not None:
             return self.bind
         else:
+            if isinstance(mapper, type):
+                mapper = _class_mapper(mapper)
             e = mapper.mapped_table.bind
             if e is None:
                 raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
index 43b95a0fdc934c04d3bbe356cd9b452abbccc91a..f2bc93d3a94d4c48c22b9d17adb54a4042053c0f 100644 (file)
@@ -463,9 +463,9 @@ class EagerLoader(AbstractRelationLoader):
         # row-loading phase to match up AliasedClause objects with the current
         # LoaderStack position.
         if parentclauses:
-            path = parentclauses.path + (self.parent.base_mapper(), self.key)
+            path = parentclauses.path + (self.parent.base_mapper, self.key)
         else:
-            path = (self.parent.base_mapper(), self.key)
+            path = (self.parent.base_mapper, self.key)
         
         if self.join_depth:
             if len(path) / 2 > self.join_depth:
index c0eebe3b0b13bd83b111628c8a0fc66989bea138..7acb26341f73a0ddb01c5427dc7fbf44897ced7e 100644 (file)
@@ -307,7 +307,7 @@ class UOWTransaction(object):
             if dontcreate:
                 return None
                 
-            base_mapper = mapper.base_mapper()
+            base_mapper = mapper.base_mapper
             if base_mapper in self.tasks:
                 base_task = self.tasks[base_mapper]
             else:
@@ -336,8 +336,8 @@ class UOWTransaction(object):
         # also convert to the "base mapper", the parentmost task at the top of an inheritance chain
         # dependency sorting is done via non-inheriting mappers only, dependencies between mappers
         # in the same inheritance chain is done at the per-object level
-        mapper = mapper.primary_mapper().base_mapper()
-        dependency = dependency.primary_mapper().base_mapper()
+        mapper = mapper.primary_mapper().base_mapper
+        dependency = dependency.primary_mapper().base_mapper
 
         self.dependencies.add((mapper, dependency))
 
@@ -715,8 +715,8 @@ class UOWTask(object):
             return l
 
         def dependency_in_cycles(dep):
-            proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper(), True)
-            targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True)
+            proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True)
+            targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True)
             return targettask in cycles and (proctask is not None and proctask in cycles)
 
         # organize all original UOWDependencyProcessors by their target task
index ee64e82f27f7c6d2a7d43b324fe0ff36e3a75f81..3fc13a50dc7e0b416d7934f02bb4c68a2fd2d693 100644 (file)
@@ -894,32 +894,41 @@ class ClauseVisitor(object):
         meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
         if meth:
             return meth(obj, **kwargs)
-            
+
+    def iterate(self, obj, stop_on=None):
+        stack = [obj]
+        traversal = []
+        while len(stack) > 0:
+            t = stack.pop()
+            if stop_on is None or t not in stop_on:
+                yield t
+                traversal.insert(0, t)
+                for c in t.get_children(**self.__traverse_options__):
+                    stack.append(c)
+        
     def traverse(self, obj, stop_on=None, clone=False):
         if clone:
             obj = obj._clone()
-
-        v = self
-        visitors = []
-        while v is not None:
-            visitors.append(v)
-            v = getattr(v, '_next', None)
-
-        def _trav(obj):
-            if stop_on is not None and obj in stop_on:
-                return
-            if clone:
-                obj._copy_internals()
-            for c in obj.get_children(**self.__traverse_options__):
-                _trav(c)
-
-            for v in visitors:
-                meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+            
+        stack = [obj]
+        traversal = []
+        while len(stack) > 0:
+            t = stack.pop()
+            if stop_on is None or t not in stop_on:
+                traversal.insert(0, t)
+                if clone:
+                    t._copy_internals()
+                for c in t.get_children(**self.__traverse_options__):
+                    stack.append(c)
+        for target in traversal:
+            v = self
+            while v is not None:
+                meth = getattr(v, "visit_%s" % target.__visit_name__, None)
                 if meth:
-                    meth(obj)
-        _trav(obj)
+                    meth(target)
+                v = getattr(v, '_next', None)
         return obj
-        
+
     def chain(self, visitor):
         """'chain' an additional ClauseVisitor onto this ClauseVisitor.
         
@@ -2070,6 +2079,9 @@ class _TextClause(ClauseElement):
     def supports_execution(self):
         return True
 
+    def _table_iterator(self):
+        return iter([])
+
 class _Null(ColumnElement):
     """Represent the NULL keyword in a SQL statement.
 
@@ -2592,6 +2604,9 @@ class Alias(FromClause):
     def supports_execution(self):
         return self.original.supports_execution()
 
+    def _table_iterator(self):
+        return self.original._table_iterator()
+
     def _locate_oid_column(self):
         if self.selectable.oid_column is not None:
             return self.selectable.oid_column._make_proxy(self)
@@ -3065,6 +3080,11 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
     def get_children(self, column_collections=True, **kwargs):
         return (column_collections and list(self.c) or []) + \
             [self._order_by_clause, self._group_by_clause] + list(self.selects)
+
+    def _table_iterator(self):
+        for s in self.selects:
+            for t in s._table_iterator():
+                yield t
             
     def _find_engine(self):
         for s in self.selects:
@@ -3334,6 +3354,11 @@ class Select(_SelectBaseMixin, FromClause):
     def intersect_all(self, other, **kwargs):
         return intersect_all(self, other, **kwargs)
 
+    def _table_iterator(self):
+        for t in NoColumnVisitor().iterate(self):
+            if isinstance(t, TableClause):
+                yield t
+        
     def _find_engine(self):
         """Try to return a Engine, either explicitly set in this
         object, or searched within the from clauses for one.
@@ -3365,6 +3390,9 @@ class _UpdateBase(ClauseElement):
     def supports_execution(self):
         return True
 
+    def _table_iterator(self):
+        return iter([self.table])
+        
     def _process_colparams(self, parameters):
         """Receive the *values* of an ``INSERT`` or ``UPDATE``
         statement and construct appropriate bind parameters.
index 4720593b6e803d279a8fe8e068b5d05bff3256ca..8e12b819d086beb85c106ec4ec69a8172b47aa02 100644 (file)
@@ -74,6 +74,22 @@ class SessionTest(AssertMixin):
         # then see if expunge fails
         session.expunge(u)
     
+    def test_binds_from_expression(self):
+        """test that Session can extract Table objects from ClauseElements and match them to tables."""
+        Session = sessionmaker(binds={users:testbase.db, addresses:testbase.db})
+        sess = Session()
+        sess.execute(users.insert(), params=dict(user_id=1, user_name='ed'))
+        assert sess.execute(users.select()).fetchall() == [(1, 'ed')]
+        
+        mapper(Address, addresses)
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all")
+        })
+        Session = sessionmaker(binds={User:testbase.db, Address:testbase.db})
+        sess.execute(users.insert(), params=dict(user_id=2, user_name='fred'))
+        assert sess.execute(users.select()).fetchall() == [(1, 'ed'), (2, 'fred')]
+        
+        
     @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang
     def test_transaction(self):
         class User(object):pass
index 6c1d7ad8105f4d94a9ef9f86c1ac4a098d6ef024..9ee201202156d2b85d90291ced061afe94bb71be 100644 (file)
@@ -6,7 +6,7 @@ import testbase
 import unittest, re, sys, os, operator
 from cStringIO import StringIO
 import testlib.config as config
-sql, MetaData, clear_mappers = None, None, None
+sql, MetaData, clear_mappers, Session = None, None, None, None
 
 
 __all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest')
@@ -323,6 +323,10 @@ class ORMTest(AssertMixin):
         _otest_metadata.drop_all()
 
     def tearDown(self):
+        global Session
+        if Session is None:
+            from sqlalchemy.orm.session import Session
+        Session.close_all()
         global clear_mappers
         if clear_mappers is None:
             from sqlalchemy.orm import clear_mappers