From: Mike Bayer Date: Tue, 14 Aug 2007 03:19:46 +0000 (+0000) Subject: - base_mapper() becomes a plain attribute X-Git-Tag: rel_0_4beta2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b9ed823528ef88fb0dc64120e1e501306f4c3768;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - base_mapper() becomes a plain attribute - session.execute() and scalar() can search for a Table with which to bind from using the given ClauseElement - session automatically extrapolates tables from mappers with binds, also uses base_mapper so that inheritance hierarchies bind automatically - moved ClauseVisitor traversal back to inlined non-recursive --- diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index c04771b232..88c689a87c 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -420,6 +420,7 @@ class MapperStub(object): def __init__(self, parent, mapper, key): self.mapper = mapper + self.base_mapper = self self.class_ = mapper.class_ self._inheriting_mappers = [] @@ -438,5 +439,3 @@ class MapperStub(object): def primary_mapper(self): return self - def base_mapper(self): - return self diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index abaeff49c5..74b184a7c2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -462,7 +462,7 @@ class LoaderStack(object): self.__stack.append(key) def push_mapper(self, mapper): - self.__stack.append(mapper.base_mapper()) + self.__stack.append(mapper.base_mapper) def pop(self): self.__stack.pop() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 85a7f426c0..9e3cb3aaf2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -368,7 +368,11 @@ class Mapper(object): self.polymorphic_map = self.inherits.polymorphic_map self.batch = self.inherits.batch self.inherits._inheriting_mappers.add(self) + self.base_mapper = self.inherits.base_mapper + self._all_tables = self.inherits._all_tables else: + self._all_tables = util.Set() + self.base_mapper = self self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: @@ -424,6 +428,7 @@ class Mapper(object): # go through all of our represented tables # and assemble primary key columns for t in self.tables + [self.mapped_table]: + self._all_tables.add(t) try: l = self.pks_by_table[t] except KeyError: @@ -534,7 +539,7 @@ class Mapper(object): result[binary.right] = util.Set([binary.left]) vis = mapperutil.BinaryVisitor(visit_binary) - for mapper in self.base_mapper().polymorphic_iterator(): + for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition is not None: vis.traverse(mapper.inherit_condition) @@ -716,19 +721,10 @@ class Mapper(object): if self.entity_name is None: self.class_.c = self.c - def base_mapper(self): - """Return the ultimate base mapper in an inheritance chain.""" - - # TODO: calculate this at mapper setup time - if self.inherits is not None: - return self.inherits.base_mapper() - else: - return self - def common_parent(self, other): """Return true if the given mapper shares a common inherited parent as this mapper.""" - return self.base_mapper() is other.base_mapper() + return self.base_mapper is other.base_mapper def isa(self, other): """Return True if the given mapper inherits from this mapper.""" @@ -752,7 +748,7 @@ class Mapper(object): all their inheriting mappers as well. To iterate through an entire hierarchy, use - ``mapper.base_mapper().polymorphic_iterator()``.""" + ``mapper.base_mapper.polymorphic_iterator()``.""" yield self for mapper in self._inheriting_mappers: @@ -1033,7 +1029,7 @@ class Mapper(object): updated_objects = util.Set() table_to_mapper = {} - for mapper in self.base_mapper().polymorphic_iterator(): + for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: table_to_mapper.setdefault(t, mapper) @@ -1247,7 +1243,7 @@ class Mapper(object): deleted_objects = util.Set() table_to_mapper = {} - for mapper in self.base_mapper().polymorphic_iterator(): + for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: table_to_mapper.setdefault(t, mapper) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1dfd1b665c..f8589dbe42 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -9,6 +9,7 @@ from sqlalchemy import util, exceptions, sql, engine from sqlalchemy.orm import unitofwork, query, util as mapperutil from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper +from sqlalchemy.orm.mapper import Mapper __all__ = ['Session', 'SessionTransaction', 'SessionExtension'] @@ -146,11 +147,8 @@ class SessionTransaction(object): self.autoflush = autoflush self.nested = nested - def connection(self, mapper_or_class, entity_name=None, **kwargs): - if isinstance(mapper_or_class, type): - mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - engine = self.session.get_bind(mapper_or_class, **kwargs) - return self.get_or_add(engine) + def connection(self, bindkey, **kwargs): + return self.session.connection(bindkey, **kwargs) def _begin(self, **kwargs): return SessionTransaction(self.session, self, **kwargs) @@ -406,10 +404,13 @@ class Session(object): self._mapper_flush_opts = {} if binds is not None: - for mapperortable, value in binds: + for mapperortable, value in binds.iteritems(): if isinstance(mapperortable, type): - mapperortable = _class_mapper(mapperortable) + mapperortable = _class_mapper(mapperortable).base_mapper self.__binds[mapperortable] = value + if isinstance(mapperortable, Mapper): + for t in mapperortable._all_tables: + self.__binds[t] = value if self.transactional: self.begin() @@ -504,11 +505,14 @@ class Session(object): to multiple engines or connections, or is not bound to any connectable. """ + return self.__connection(self.get_bind(mapper)) + + def __connection(self, engine, **kwargs): if self.transaction is not None: - return self.transaction.connection(mapper) + return self.transaction.get_or_add(engine) else: - return self.get_bind(mapper).contextual_connect(**kwargs) - + return engine.contextual_connect(**kwargs) + def execute(self, clause, params=None, mapper=None, **kwargs): """Using the given mapper to identify the appropriate ``Engine`` or ``Connection`` to be used for statement execution, execute the @@ -520,12 +524,17 @@ class Session(object): then the ``ResultProxy`` 's ``close()`` method will release the resources of the underlying ``Connection``. """ - return self.connection(mapper, close_with_result=True).execute(clause, params or {}, **kwargs) + + engine = self.get_bind(mapper, clause=clause) + + return self.__connection(engine, close_with_result=True).execute(clause, params or {}, **kwargs) def scalar(self, clause, params=None, mapper=None, **kwargs): """Like execute() but return a scalar result.""" - return self.connection(mapper, close_with_result=True).scalar(clause, params or {}, **kwargs) + engine = self.get_bind(mapper, clause=clause) + + return self.__connection(engine, close_with_result=True).scalar(clause, params or {}, **kwargs) def close(self): """Close this Session. @@ -575,7 +584,9 @@ class Session(object): if isinstance(mapper, type): mapper = _class_mapper(mapper, entity_name=entity_name) - self.__binds[mapper] = bind + self.__binds[mapper.base_mapper] = bind + for t in mapper._all_tables: + self.__binds[t] = bind def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. @@ -586,45 +597,32 @@ class Session(object): self.__binds[table] = bind - def get_bind(self, mapper): - """Return the ``Engine`` or ``Connection`` which is used to execute - statements on behalf of the given `mapper`. - - Calling ``connect()`` on the return result will always result - in a ``Connection`` object. This method disregards any - ``SessionTransaction`` that may be in progress. + def get_bind(self, mapper, clause=None): - The order of searching is as follows: - - 1. if an ``Engine`` or ``Connection`` was bound to this ``Mapper`` - specifically within this ``Session``, return that ``Engine`` or - ``Connection``. - - 2. if an ``Engine`` or ``Connection`` was bound to this `mapper` 's - underlying ``Table`` within this ``Session`` (i.e. not to the ``Table`` - directly), return that ``Engine`` or ``Connection``. - - 3. if an ``Engine`` or ``Connection`` was bound to this ``Session``, - return that ``Engine`` or ``Connection``. - - 4. finally, return the ``Engine`` which was bound directly to the - ``Table`` 's ``MetaData`` object. - - If no ``Engine`` is bound to the ``Table``, an exception is raised. - """ - - if mapper is None: + if mapper is None and clause is None: if self.bind is not None: return self.bind else: raise exceptions.InvalidRequestError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") - elif self.__binds.has_key(mapper): - return self.__binds[mapper] - elif self.__binds.has_key(mapper.compile().mapped_table): - return self.__binds[mapper.mapped_table] - elif self.bind is not None: + + elif len(self.__binds): + if mapper is not None: + if isinstance(mapper, type): + mapper = _class_mapper(mapper) + if self.__binds.has_key(mapper.base_mapper): + return self.__binds[mapper.base_mapper] + elif self.__binds.has_key(mapper.compile().mapped_table): + return self.__binds[mapper.mapped_table] + if clause is not None: + for t in clause._table_iterator(): + if t in self.__binds: + return self.__binds[t] + + if self.bind is not None: return self.bind else: + if isinstance(mapper, type): + mapper = _class_mapper(mapper) e = mapper.mapped_table.bind if e is None: raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 43b95a0fdc..f2bc93d3a9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -463,9 +463,9 @@ class EagerLoader(AbstractRelationLoader): # row-loading phase to match up AliasedClause objects with the current # LoaderStack position. if parentclauses: - path = parentclauses.path + (self.parent.base_mapper(), self.key) + path = parentclauses.path + (self.parent.base_mapper, self.key) else: - path = (self.parent.base_mapper(), self.key) + path = (self.parent.base_mapper, self.key) if self.join_depth: if len(path) / 2 > self.join_depth: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c0eebe3b0b..7acb26341f 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -307,7 +307,7 @@ class UOWTransaction(object): if dontcreate: return None - base_mapper = mapper.base_mapper() + base_mapper = mapper.base_mapper if base_mapper in self.tasks: base_task = self.tasks[base_mapper] else: @@ -336,8 +336,8 @@ class UOWTransaction(object): # also convert to the "base mapper", the parentmost task at the top of an inheritance chain # dependency sorting is done via non-inheriting mappers only, dependencies between mappers # in the same inheritance chain is done at the per-object level - mapper = mapper.primary_mapper().base_mapper() - dependency = dependency.primary_mapper().base_mapper() + mapper = mapper.primary_mapper().base_mapper + dependency = dependency.primary_mapper().base_mapper self.dependencies.add((mapper, dependency)) @@ -715,8 +715,8 @@ class UOWTask(object): return l def dependency_in_cycles(dep): - proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper(), True) - targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True) + proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True) + targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True) return targettask in cycles and (proctask is not None and proctask in cycles) # organize all original UOWDependencyProcessors by their target task diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index ee64e82f27..3fc13a50dc 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -894,32 +894,41 @@ class ClauseVisitor(object): meth = getattr(self, "visit_%s" % obj.__visit_name__, None) if meth: return meth(obj, **kwargs) - + + def iterate(self, obj, stop_on=None): + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + yield t + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + def traverse(self, obj, stop_on=None, clone=False): if clone: obj = obj._clone() - - v = self - visitors = [] - while v is not None: - visitors.append(v) - v = getattr(v, '_next', None) - - def _trav(obj): - if stop_on is not None and obj in stop_on: - return - if clone: - obj._copy_internals() - for c in obj.get_children(**self.__traverse_options__): - _trav(c) - - for v in visitors: - meth = getattr(v, "visit_%s" % obj.__visit_name__, None) + + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + traversal.insert(0, t) + if clone: + t._copy_internals() + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + for target in traversal: + v = self + while v is not None: + meth = getattr(v, "visit_%s" % target.__visit_name__, None) if meth: - meth(obj) - _trav(obj) + meth(target) + v = getattr(v, '_next', None) return obj - + def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. @@ -2070,6 +2079,9 @@ class _TextClause(ClauseElement): def supports_execution(self): return True + def _table_iterator(self): + return iter([]) + class _Null(ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -2592,6 +2604,9 @@ class Alias(FromClause): def supports_execution(self): return self.original.supports_execution() + def _table_iterator(self): + return self.original._table_iterator() + def _locate_oid_column(self): if self.selectable.oid_column is not None: return self.selectable.oid_column._make_proxy(self) @@ -3065,6 +3080,11 @@ class CompoundSelect(_SelectBaseMixin, FromClause): def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) + \ [self._order_by_clause, self._group_by_clause] + list(self.selects) + + def _table_iterator(self): + for s in self.selects: + for t in s._table_iterator(): + yield t def _find_engine(self): for s in self.selects: @@ -3334,6 +3354,11 @@ class Select(_SelectBaseMixin, FromClause): def intersect_all(self, other, **kwargs): return intersect_all(self, other, **kwargs) + def _table_iterator(self): + for t in NoColumnVisitor().iterate(self): + if isinstance(t, TableClause): + yield t + def _find_engine(self): """Try to return a Engine, either explicitly set in this object, or searched within the from clauses for one. @@ -3365,6 +3390,9 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + def _table_iterator(self): + return iter([self.table]) + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. diff --git a/test/orm/session.py b/test/orm/session.py index 4720593b6e..8e12b819d0 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -74,6 +74,22 @@ class SessionTest(AssertMixin): # then see if expunge fails session.expunge(u) + def test_binds_from_expression(self): + """test that Session can extract Table objects from ClauseElements and match them to tables.""" + Session = sessionmaker(binds={users:testbase.db, addresses:testbase.db}) + sess = Session() + sess.execute(users.insert(), params=dict(user_id=1, user_name='ed')) + assert sess.execute(users.select()).fetchall() == [(1, 'ed')] + + mapper(Address, addresses) + mapper(User, users, properties={ + 'addresses':relation(Address, backref=backref("user", cascade="all"), cascade="all") + }) + Session = sessionmaker(binds={User:testbase.db, Address:testbase.db}) + sess.execute(users.insert(), params=dict(user_id=2, user_name='fred')) + assert sess.execute(users.select()).fetchall() == [(1, 'ed'), (2, 'fred')] + + @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang def test_transaction(self): class User(object):pass diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 6c1d7ad810..9ee2012021 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -6,7 +6,7 @@ import testbase import unittest, re, sys, os, operator from cStringIO import StringIO import testlib.config as config -sql, MetaData, clear_mappers = None, None, None +sql, MetaData, clear_mappers, Session = None, None, None, None __all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest') @@ -323,6 +323,10 @@ class ORMTest(AssertMixin): _otest_metadata.drop_all() def tearDown(self): + global Session + if Session is None: + from sqlalchemy.orm.session import Session + Session.close_all() global clear_mappers if clear_mappers is None: from sqlalchemy.orm import clear_mappers