def __init__(self, parent, mapper, key):
self.mapper = mapper
+ self.base_mapper = self
self.class_ = mapper.class_
self._inheriting_mappers = []
def primary_mapper(self):
return self
- def base_mapper(self):
- return self
self.__stack.append(key)
def push_mapper(self, mapper):
- self.__stack.append(mapper.base_mapper())
+ self.__stack.append(mapper.base_mapper)
def pop(self):
self.__stack.pop()
self.polymorphic_map = self.inherits.polymorphic_map
self.batch = self.inherits.batch
self.inherits._inheriting_mappers.add(self)
+ self.base_mapper = self.inherits.base_mapper
+ self._all_tables = self.inherits._all_tables
else:
+ self._all_tables = util.Set()
+ self.base_mapper = self
self._synchronizer = None
self.mapped_table = self.local_table
if self.polymorphic_identity is not None:
# go through all of our represented tables
# and assemble primary key columns
for t in self.tables + [self.mapped_table]:
+ self._all_tables.add(t)
try:
l = self.pks_by_table[t]
except KeyError:
result[binary.right] = util.Set([binary.left])
vis = mapperutil.BinaryVisitor(visit_binary)
- for mapper in self.base_mapper().polymorphic_iterator():
+ for mapper in self.base_mapper.polymorphic_iterator():
if mapper.inherit_condition is not None:
vis.traverse(mapper.inherit_condition)
if self.entity_name is None:
self.class_.c = self.c
- def base_mapper(self):
- """Return the ultimate base mapper in an inheritance chain."""
-
- # TODO: calculate this at mapper setup time
- if self.inherits is not None:
- return self.inherits.base_mapper()
- else:
- return self
-
def common_parent(self, other):
"""Return true if the given mapper shares a common inherited parent as this mapper."""
- return self.base_mapper() is other.base_mapper()
+ return self.base_mapper is other.base_mapper
def isa(self, other):
"""Return True if the given mapper inherits from this mapper."""
all their inheriting mappers as well.
To iterate through an entire hierarchy, use
- ``mapper.base_mapper().polymorphic_iterator()``."""
+ ``mapper.base_mapper.polymorphic_iterator()``."""
yield self
for mapper in self._inheriting_mappers:
updated_objects = util.Set()
table_to_mapper = {}
- for mapper in self.base_mapper().polymorphic_iterator():
+ for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
table_to_mapper.setdefault(t, mapper)
deleted_objects = util.Set()
table_to_mapper = {}
- for mapper in self.base_mapper().polymorphic_iterator():
+ for mapper in self.base_mapper.polymorphic_iterator():
for t in mapper.tables:
table_to_mapper.setdefault(t, mapper)
from sqlalchemy.orm import unitofwork, query, util as mapperutil
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
+from sqlalchemy.orm.mapper import Mapper
__all__ = ['Session', 'SessionTransaction', 'SessionExtension']
self.autoflush = autoflush
self.nested = nested
- def connection(self, mapper_or_class, entity_name=None, **kwargs):
- if isinstance(mapper_or_class, type):
- mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- engine = self.session.get_bind(mapper_or_class, **kwargs)
- return self.get_or_add(engine)
+ def connection(self, bindkey, **kwargs):
+ return self.session.connection(bindkey, **kwargs)
def _begin(self, **kwargs):
return SessionTransaction(self.session, self, **kwargs)
self._mapper_flush_opts = {}
if binds is not None:
- for mapperortable, value in binds:
+ for mapperortable, value in binds.iteritems():
if isinstance(mapperortable, type):
- mapperortable = _class_mapper(mapperortable)
+ mapperortable = _class_mapper(mapperortable).base_mapper
self.__binds[mapperortable] = value
+ if isinstance(mapperortable, Mapper):
+ for t in mapperortable._all_tables:
+ self.__binds[t] = value
if self.transactional:
self.begin()
to multiple engines or connections, or is not bound to any connectable.
"""
+ return self.__connection(self.get_bind(mapper))
+
+ def __connection(self, engine, **kwargs):
if self.transaction is not None:
- return self.transaction.connection(mapper)
+ return self.transaction.get_or_add(engine)
else:
- return self.get_bind(mapper).contextual_connect(**kwargs)
-
+ return engine.contextual_connect(**kwargs)
+
def execute(self, clause, params=None, mapper=None, **kwargs):
"""Using the given mapper to identify the appropriate ``Engine``
or ``Connection`` to be used for statement execution, execute the
then the ``ResultProxy`` 's ``close()`` method will release the
resources of the underlying ``Connection``.
"""
- return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs)
+
+ engine = self.get_bind(mapper, clause=clause)
+
+ return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs)
def scalar(self, clause, params=None, mapper=None, **kwargs):
"""Like execute() but return a scalar result."""
- return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs)
+ engine = self.get_bind(mapper, clause=clause)
+
+ return self.__connection(engine, close_with_result=True).scalar(clause, params or {}, **kwargs)
def close(self):
"""Close this Session.
if isinstance(mapper, type):
mapper = _class_mapper(mapper, entity_name=entity_name)
- self.__binds[mapper] = bind
+ self.__binds[mapper.base_mapper] = bind
+ for t in mapper._all_tables:
+ self.__binds[t] = bind
def bind_table(self, table, bind):
"""Bind the given `table` to the given ``Engine`` or ``Connection``.
self.__binds[table] = bind
- def get_bind(self, mapper):
- """Return the ``Engine`` or ``Connection`` which is used to execute
- statements on behalf of the given `mapper`.
-
- Calling ``connect()`` on the return result will always result
- in a ``Connection`` object. This method disregards any
- ``SessionTransaction`` that may be in progress.
+ def get_bind(self, mapper, clause=None):
- The order of searching is as follows:
-
- 1. if an ``Engine`` or ``Connection`` was bound to this ``Mapper``
- specifically within this ``Session``, return that ``Engine`` or
- ``Connection``.
-
- 2. if an ``Engine`` or ``Connection`` was bound to this `mapper` 's
- underlying ``Table`` within this ``Session`` (i.e. not to the ``Table``
- directly), return that ``Engine`` or ``Connection``.
-
- 3. if an ``Engine`` or ``Connection`` was bound to this ``Session``,
- return that ``Engine`` or ``Connection``.
-
- 4. finally, return the ``Engine`` which was bound directly to the
- ``Table`` 's ``MetaData`` object.
-
- If no ``Engine`` is bound to the ``Table``, an exception is raised.
- """
-
- if mapper is None:
+ if mapper is None and clause is None:
if self.bind is not None:
return self.bind
else:
raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
- elif self.__binds.has_key(mapper):
- return self.__binds[mapper]
- elif self.__binds.has_key(mapper.compile().mapped_table):
- return self.__binds[mapper.mapped_table]
- elif self.bind is not None:
+
+ elif len(self.__binds):
+ if mapper is not None:
+ if isinstance(mapper, type):
+ mapper = _class_mapper(mapper)
+ if self.__binds.has_key(mapper.base_mapper):
+ return self.__binds[mapper.base_mapper]
+ elif self.__binds.has_key(mapper.compile().mapped_table):
+ return self.__binds[mapper.mapped_table]
+ if clause is not None:
+ for t in clause._table_iterator():
+ if t in self.__binds:
+ return self.__binds[t]
+
+ if self.bind is not None:
return self.bind
else:
+ if isinstance(mapper, type):
+ mapper = _class_mapper(mapper)
e = mapper.mapped_table.bind
if e is None:
raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
# row-loading phase to match up AliasedClause objects with the current
# LoaderStack position.
if parentclauses:
- path = parentclauses.path + (self.parent.base_mapper(), self.key)
+ path = parentclauses.path + (self.parent.base_mapper, self.key)
else:
- path = (self.parent.base_mapper(), self.key)
+ path = (self.parent.base_mapper, self.key)
if self.join_depth:
if len(path) / 2 > self.join_depth:
if dontcreate:
return None
- base_mapper = mapper.base_mapper()
+ base_mapper = mapper.base_mapper
if base_mapper in self.tasks:
base_task = self.tasks[base_mapper]
else:
# also convert to the "base mapper", the parentmost task at the top of an inheritance chain
# dependency sorting is done via non-inheriting mappers only, dependencies between mappers
# in the same inheritance chain is done at the per-object level
- mapper = mapper.primary_mapper().base_mapper()
- dependency = dependency.primary_mapper().base_mapper()
+ mapper = mapper.primary_mapper().base_mapper
+ dependency = dependency.primary_mapper().base_mapper
self.dependencies.add((mapper, dependency))
return l
def dependency_in_cycles(dep):
- proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper(), True)
- targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True)
+ proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True)
+ targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True)
return targettask in cycles and (proctask is not None and proctask in cycles)
# organize all original UOWDependencyProcessors by their target task
meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj, **kwargs)
-
+
+ def iterate(self, obj, stop_on=None):
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ yield t
+ traversal.insert(0, t)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+
def traverse(self, obj, stop_on=None, clone=False):
if clone:
obj = obj._clone()
-
- v = self
- visitors = []
- while v is not None:
- visitors.append(v)
- v = getattr(v, '_next', None)
-
- def _trav(obj):
- if stop_on is not None and obj in stop_on:
- return
- if clone:
- obj._copy_internals()
- for c in obj.get_children(**self.__traverse_options__):
- _trav(c)
-
- for v in visitors:
- meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+
+ stack = [obj]
+ traversal = []
+ while len(stack) > 0:
+ t = stack.pop()
+ if stop_on is None or t not in stop_on:
+ traversal.insert(0, t)
+ if clone:
+ t._copy_internals()
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
+ for target in traversal:
+ v = self
+ while v is not None:
+ meth = getattr(v, "visit_%s" % target.__visit_name__, None)
if meth:
- meth(obj)
- _trav(obj)
+ meth(target)
+ v = getattr(v, '_next', None)
return obj
-
+
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
def supports_execution(self):
return True
+ def _table_iterator(self):
+ return iter([])
+
class _Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
def supports_execution(self):
return self.original.supports_execution()
+ def _table_iterator(self):
+ return self.original._table_iterator()
+
def _locate_oid_column(self):
if self.selectable.oid_column is not None:
return self.selectable.oid_column._make_proxy(self)
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) + \
[self._order_by_clause, self._group_by_clause] + list(self.selects)
+
+ def _table_iterator(self):
+ for s in self.selects:
+ for t in s._table_iterator():
+ yield t
def _find_engine(self):
for s in self.selects:
def intersect_all(self, other, **kwargs):
return intersect_all(self, other, **kwargs)
+ def _table_iterator(self):
+ for t in NoColumnVisitor().iterate(self):
+ if isinstance(t, TableClause):
+ yield t
+
def _find_engine(self):
"""Try to return a Engine, either explicitly set in this
object, or searched within the from clauses for one.
def supports_execution(self):
return True
+ def _table_iterator(self):
+ return iter([self.table])
+
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
# then see if expunge fails
session.expunge(u)
+ def test_binds_from_expression(self):
+ """test that Session can extract Table objects from ClauseElements and match them to tables."""
+ Session = sessionmaker(binds={users:testbase.db, addresses:testbase.db})
+ sess = Session()
+ sess.execute(users.insert(), params=dict(user_id=1, user_name='ed'))
+ assert sess.execute(users.select()).fetchall() == [(1, 'ed')]
+
+ mapper(Address, addresses)
+ mapper(User, users, properties={
+ 'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all")
+ })
+ Session = sessionmaker(binds={User:testbase.db, Address:testbase.db})
+ sess.execute(users.insert(), params=dict(user_id=2, user_name='fred'))
+ assert sess.execute(users.select()).fetchall() == [(1, 'ed'), (2, 'fred')]
+
+
@testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang
def test_transaction(self):
class User(object):pass
import unittest, re, sys, os, operator
from cStringIO import StringIO
import testlib.config as config
-sql, MetaData, clear_mappers = None, None, None
+sql, MetaData, clear_mappers, Session = None, None, None, None
__all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest')
_otest_metadata.drop_all()
def tearDown(self):
+ global Session
+ if Session is None:
+ from sqlalchemy.orm.session import Session
+ Session.close_all()
global clear_mappers
if clear_mappers is None:
from sqlalchemy.orm import clear_mappers