From 2a171e818c2e7cfadcd286399a3740147ff0df45 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 6 Apr 2006 23:46:02 +0000 Subject: [PATCH] split up Session into Session/LegacySession, added some new constructor args created AbstractEngine class which provides base for SQLEngine and will also provide base for ConnectionProxy, so SQL binding can be to an engine or specific connection resource ClauseElements get using() method which can take AbstractEngines for execution made more separation between SchemaItems and bound engine --- lib/sqlalchemy/ext/proxy.py | 10 +- lib/sqlalchemy/mapping/objectstore.py | 209 ++++++++++++++------------ lib/sqlalchemy/mapping/query.py | 5 +- lib/sqlalchemy/schema.py | 55 +++++-- lib/sqlalchemy/sql.py | 66 ++++---- 5 files changed, 204 insertions(+), 141 deletions(-) diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py index 38325bea35..a24f089e9c 100644 --- a/lib/sqlalchemy/ext/proxy.py +++ b/lib/sqlalchemy/ext/proxy.py @@ -24,7 +24,15 @@ class BaseProxyEngine(schema.SchemaEngine): def reflecttable(self, table): return self.get_engine().reflecttable(table) - + def execute_compiled(self, *args, **kwargs): + return self.get_engine().execute_compiled(*args, **kwargs) + def compiler(self, *args, **kwargs): + return self.get_engine().compiler(*args, **kwargs) + def schemagenerator(self, *args, **kwargs): + return self.get_engine().schemagenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return self.get_engine().schemadropper(*args, **kwargs) + def hash_key(self): return "%s(%s)" % (self.__class__.__name__, id(self)) diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 1491d39ac0..faf5ddbd6b 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -17,7 +17,7 @@ import sqlalchemy class Session(object): """Maintains a UnitOfWork instance, including transaction state.""" - def __init__(self, nest_on=None, hash_key=None): + def __init__(self, hash_key=None, new_imap=True, import_session=None): """Initialize the objectstore with a UnitOfWork registry. If called with no arguments, creates a single UnitOfWork for all operations. @@ -26,31 +26,23 @@ class Session(object): hash_key - the hash_key used to identify objects against this session, which defaults to the id of the Session instance. """ - self.uow = unitofwork.UnitOfWork() - self.parent_uow = None - self.begin_count = 0 - self.nest_on = util.to_list(nest_on) - self.__pushed_count = 0 + if import_session is not None: + self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map) + elif new_imap is False: + self.uow = unitofwork.UnitOfWork(identity_map=objectstore.get_session().uow.identity_map) + else: + self.uow = unitofwork.UnitOfWork() + + self.binds = {} if hash_key is None: self.hash_key = id(self) else: self.hash_key = hash_key _sessions[self.hash_key] = self - def was_pushed(self): - if self.nest_on is None: - return - self.__pushed_count += 1 - if self.__pushed_count == 1: - for n in self.nest_on: - n.push_session() - def was_popped(self): - if self.nest_on is None or self.__pushed_count == 0: - return - self.__pushed_count -= 1 - if self.__pushed_count == 0: - for n in self.nest_on: - n.pop_session() + def bind_table(self, table, bindto): + self.binds[table] = bindto + def get_id_key(ident, class_, entity_name=None): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -81,79 +73,12 @@ class Session(object): """ return (class_, tuple([row[column] for column in primary_key]), entity_name) get_row_key = staticmethod(get_row_key) - - class SessionTrans(object): - """returned by Session.begin(), denotes a transactionalized UnitOfWork instance. - call commit() on this to commit the transaction.""" - def __init__(self, parent, uow, isactive): - self.__parent = parent - self.__isactive = isactive - self.__uow = uow - isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") - parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.") - uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.") - def begin(self): - """calls begin() on the underlying Session object, returning a new no-op SessionTrans object.""" - if self.parent.uow is not self.uow: - raise InvalidRequestError("This SessionTrans is no longer valid") - return self.parent.begin() - def commit(self): - """commits the transaction noted by this SessionTrans object.""" - self.__parent._trans_commit(self) - self.__isactive = False - def rollback(self): - """rolls back the current UnitOfWork transaction, in the case that begin() - has been called. The changes logged since the begin() call are discarded.""" - self.__parent._trans_rollback(self) - self.__isactive = False - - def begin(self): - """begins a new UnitOfWork transaction and returns a tranasaction-holding - object. commit() or rollback() should be called on the returned object. - commit() on the Session will do nothing while a transaction is pending, and further - calls to begin() will return no-op transactional objects.""" - if self.parent_uow is not None: - return Session.SessionTrans(self, self.uow, False) - self.parent_uow = self.uow - self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) - return Session.SessionTrans(self, self.uow, True) def engines(self, mapper): return [t.engine for t in mapper.tables] - def _trans_commit(self, trans): - if trans.uow is self.uow and trans.isactive: - try: - self._commit_uow() - finally: - self.uow = self.parent_uow - self.parent_uow = None - def _trans_rollback(self, trans): - if trans.uow is self.uow: - self.uow = self.parent_uow - self.parent_uow = None - - def _commit_uow(self, *obj): - self.was_pushed() - try: - self.uow.flush(self, *obj) - finally: - self.was_popped() - - def commit(self, *objects): - """commits the current UnitOfWork transaction. called with - no arguments, this is only used - for "implicit" transactions when there was no begin(). - if individual objects are submitted, then only those objects are committed, and the - begin/commit cycle is not affected.""" - # if an object list is given, commit just those but dont - # change begin/commit status - if len(objects): - self._commit_uow(*objects) - self.uow.flush(self, *objects) - return - if self.parent_uow is None: - self._commit_uow() + def flush(self, *obj): + self.uow.flush(self, *obj) def refresh(self, *obj): """reloads the attributes for the given objects from the database, clears @@ -221,6 +146,95 @@ class Session(object): u.register_new(instance) return instance +class LegacySession(Session): + def __init__(self, nest_on=None, hash_key=None, **kwargs): + super(LegacySession, self).__init__(**kwargs) + self.parent_uow = None + self.begin_count = 0 + self.nest_on = util.to_list(nest_on) + self.__pushed_count = 0 + def was_pushed(self): + if self.nest_on is None: + return + self.__pushed_count += 1 + if self.__pushed_count == 1: + for n in self.nest_on: + n.push_session() + def was_popped(self): + if self.nest_on is None or self.__pushed_count == 0: + return + self.__pushed_count -= 1 + if self.__pushed_count == 0: + for n in self.nest_on: + n.pop_session() + class SessionTrans(object): + """returned by Session.begin(), denotes a transactionalized UnitOfWork instance. + call commit() on this to commit the transaction.""" + def __init__(self, parent, uow, isactive): + self.__parent = parent + self.__isactive = isactive + self.__uow = uow + isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") + parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.") + uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.") + def begin(self): + """calls begin() on the underlying Session object, returning a new no-op SessionTrans object.""" + if self.parent.uow is not self.uow: + raise InvalidRequestError("This SessionTrans is no longer valid") + return self.parent.begin() + def commit(self): + """commits the transaction noted by this SessionTrans object.""" + self.__parent._trans_commit(self) + self.__isactive = False + def rollback(self): + """rolls back the current UnitOfWork transaction, in the case that begin() + has been called. The changes logged since the begin() call are discarded.""" + self.__parent._trans_rollback(self) + self.__isactive = False + def begin(self): + """begins a new UnitOfWork transaction and returns a tranasaction-holding + object. commit() or rollback() should be called on the returned object. + commit() on the Session will do nothing while a transaction is pending, and further + calls to begin() will return no-op transactional objects.""" + if self.parent_uow is not None: + return Session.SessionTrans(self, self.uow, False) + self.parent_uow = self.uow + self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) + return Session.SessionTrans(self, self.uow, True) + def commit(self, *objects): + """commits the current UnitOfWork transaction. called with + no arguments, this is only used + for "implicit" transactions when there was no begin(). + if individual objects are submitted, then only those objects are committed, and the + begin/commit cycle is not affected.""" + # if an object list is given, commit just those but dont + # change begin/commit status + if len(objects): + self._commit_uow(*objects) + self.uow.flush(self, *objects) + return + if self.parent_uow is None: + self._commit_uow() + def _trans_commit(self, trans): + if trans.uow is self.uow and trans.isactive: + try: + self._commit_uow() + finally: + self.uow = self.parent_uow + self.parent_uow = None + def _trans_rollback(self, trans): + if trans.uow is self.uow: + self.uow = self.parent_uow + self.parent_uow = None + def _commit_uow(self, *obj): + self.was_pushed() + try: + self.uow.flush(self, *obj) + finally: + self.was_popped() + +Session = LegacySession + def get_id_key(ident, class_, entity_name=None): return Session.get_id_key(ident, class_, entity_name) @@ -228,19 +242,22 @@ def get_row_key(row, class_, primary_key, entity_name=None): return Session.get_row_key(row, class_, primary_key, entity_name) def begin(): - """begins a new UnitOfWork transaction. the next commit will affect only - objects that are created, modified, or deleted following the begin statement.""" + """deprecated. use s = Session(new_imap=False).""" return get_session().begin() def commit(*obj): - """commits the current UnitOfWork transaction. if a transaction was begun - via begin(), commits only those objects that were created, modified, or deleted - since that begin statement. otherwise commits all objects that have been + """deprecated; use flush(*obj)""" + get_session().flush(*obj) + +def flush(*obj): + """flushes the current UnitOfWork transaction. if a transaction was begun + via begin(), flushes only those objects that were created, modified, or deleted + since that begin statement. otherwise flushes all objects that have been changed. - + if individual objects are submitted, then only those objects are committed, and the begin/commit cycle is not affected.""" - get_session().commit(*obj) + get_session().flush(*obj) def clear(): """removes all current UnitOfWorks and IdentityMaps for this thread and diff --git a/lib/sqlalchemy/mapping/query.py b/lib/sqlalchemy/mapping/query.py index 09c2b9b6ec..950c2be42c 100644 --- a/lib/sqlalchemy/mapping/query.py +++ b/lib/sqlalchemy/mapping/query.py @@ -10,6 +10,7 @@ class Query(object): self.mapper = mapper self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) + self.extension = kwargs.pop('extension', self.mapper.extension) self._session = kwargs.pop('session', None) if not hasattr(mapper, '_get_clause'): _get_clause = sql.and_() @@ -66,7 +67,7 @@ class Query(object): e.g. result = usermapper.select_by(user_name = 'fred') """ - ret = self.mapper.extension.select_by(self, *args, **params) + ret = self.extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret return self.select_whereclause(self._by_clause(*args, **params)) @@ -116,7 +117,7 @@ class Query(object): in this case, the developer must insure that an adequate set of columns exists in the rowset with which to build new object instances.""" - ret = self.mapper.extension.select(self, arg=arg, **kwargs) + ret = self.extension.select(self, arg=arg, **kwargs) if ret is not mapper.EXT_PASS: return ret elif arg is not None and isinstance(arg, sql.Selectable): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 24392b3d97..acce555ab2 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -23,8 +23,17 @@ import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] +class SchemaMeta(type): + """provides universal constructor arguments for all SchemaItems""" + def __call__(self, *args, **kwargs): + engine = kwargs.pop('engine', None) + obj = type.__call__(self, *args, **kwargs) + obj._engine = engine + return obj + class SchemaItem(object): """base class for items that define a database schema.""" + __metaclass__ = SchemaMeta def _init_items(self, *args): for item in args: if item is not None: @@ -34,7 +43,20 @@ class SchemaItem(object): raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ - + +class EngineMixin(object): + """a mixin for SchemaItems that provides an "engine" accessor.""" + def _derived_engine(self): + """subclasses override this method to return an AbstractEngine + bound to a parent item""" + return None + def _get_engine(self): + if self._engine is not None: + return self._engine + else: + return self._derived_engine() + engine = property(_get_engine) + def _get_table_key(engine, name, schema): if schema is not None and schema == engine.get_default_schema_name(): schema = None @@ -43,14 +65,12 @@ def _get_table_key(engine, name, schema): else: return schema + "." + name -class TableSingleton(type): +class TableSingleton(SchemaMeta): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, engine=None, *args, **kwargs): try: - if not isinstance(engine, SchemaEngine): + if engine is not None and not isinstance(engine, SchemaEngine): args = [engine] + list(args) - engine = None - if engine is None: engine = default_engine name = str(name) # in case of incoming unicode schema = kwargs.get('schema', None) @@ -58,6 +78,10 @@ class TableSingleton(type): redefine = kwargs.pop('redefine', False) mustexist = kwargs.pop('mustexist', False) useexisting = kwargs.pop('useexisting', False) + if not engine: + table = type.__call__(self, name, engine, **kwargs) + table._init_items(*args) + return table key = _get_table_key(engine, name, schema) table = engine.tables[key] if len(args): @@ -440,15 +464,14 @@ class ForeignKey(SchemaItem): self.parent.foreign_key = self self.parent.table.foreign_keys.append(self) -class DefaultGenerator(SchemaItem): +class DefaultGenerator(SchemaItem, EngineMixin): """Base class for column "default" values.""" - def __init__(self, for_update=False, engine=None): + def __init__(self, for_update=False): self.for_update = for_update - self.engine = engine + def _derived_engine(self): + return self.column.table.engine def _set_parent(self, column): self.column = column - if self.engine is None: - self.engine = column.table.engine if self.for_update: self.column.onupdate = self else: @@ -509,7 +532,7 @@ class Sequence(DefaultGenerator): return visitor.visit_sequence(self) -class Index(SchemaItem): +class Index(SchemaItem, EngineMixin): """Represents an index of columns from a database table """ def __init__(self, name, *columns, **kw): @@ -530,7 +553,8 @@ class Index(SchemaItem): self.unique = kw.pop('unique', False) self._init_items(*columns) - engine = property(lambda s:s.table.engine) + def _derived_engine(self): + return self.table.engine def _init_items(self, *args): for column in args: self.append_column(column) @@ -570,18 +594,21 @@ class Index(SchemaItem): for c in self.columns]), (self.unique and ', unique=True') or '') -class SchemaEngine(object): +class SchemaEngine(sql.AbstractEngine): """a factory object used to create implementations for schema objects. This object is the ultimate base class for the engine.SQLEngine class.""" def __init__(self): # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} - def reflecttable(self, table): """given a table, will query the database and populate its Column and ForeignKey objects.""" raise NotImplementedError() + def schemagenerator(self, **params): + raise NotImplementedError() + def schemadropper(self, **params): + raise NotImplementedError() class SchemaVisitor(sql.ClauseVisitor): """defines the visiting for SchemaItem objects""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index f6e2d03c9a..2bc025e9f0 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -246,6 +246,12 @@ def _is_literal(element): def is_column(col): return isinstance(col, ColumnElement) +class AbstractEngine(object): + def execute_compiled(self, compiled, parameters, echo=None, **kwargs): + raise NotImplementedError() + def compiler(self, statement, parameters, **kwargs): + raise NotImplementedError() + class ClauseParameters(util.OrderedDict): """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order.""" def __init__(self, engine=None): @@ -340,8 +346,11 @@ class Compiled(ClauseVisitor): """executes this compiled object using the underlying SQLEngine""" if len(multiparams): params = multiparams - - return self.engine.execute_compiled(self, params) + + e = self.engine + if e is None: + raise InvalidRequestError("This Compiled object is not bound to any engine.") + return e.execute_compiled(self, params) def scalar(self, *multiparams, **params): """executes this compiled object via the execute() method, then @@ -356,7 +365,26 @@ class Compiled(ClauseVisitor): return row[0] else: return None - + +class Executor(object): + """handles the compilation/execution of a ClauseElement within the context of a particular AbtractEngine. This + AbstractEngine will usually be a SQLEngine or ConnectionProxy.""" + def __init__(self, clauseelement, abstractengine=None): + self.engine=abstractengine + self.clauseelement = clauseelement + def execute(self, *multiparams, **params): + return self.compile(*multiparams, **params).execute(*multiparams, **params) + def scalar(self, *multiparams, **params): + return self.compile(*multiparams, **params).scalar(*multiparams, **params) + def compile(self, *multiparams, **params): + if len(multiparams): + bindparams = multiparams[0] + else: + bindparams = params + compiler = self.engine.compiler(self.clauseelement, bindparams) + compiler.compile() + return compiler + class ClauseElement(object): """base class for elements of a programmatically constructed SQL expression.""" def _get_from_objects(self): @@ -415,10 +443,12 @@ class ClauseElement(object): engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.") - + def using(self, abstractengine): + return Executor(self, abstractengine) + def compile(self, engine = None, parameters = None, typemap=None, compiler=None): """compiles this SQL expression using its underlying SQLEngine to produce - a Compiled object. If no engine can be found, an ansisql engine is used. + a Compiled object. If no engine can be found, an ANSICompiler is used with no engine. bindparams is a dictionary representing the default bind parameters to be used with the statement. """ @@ -430,7 +460,7 @@ class ClauseElement(object): if compiler is None: import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSICompiler(self, parameters=parameters, typemap=typemap) + compiler = ansisql.ANSICompiler(self, parameters=parameters) compiler.compile() return compiler @@ -438,30 +468,10 @@ class ClauseElement(object): return str(self.compile()) def execute(self, *multiparams, **params): - """compiles and executes this SQL expression using its underlying SQLEngine. the - given **params are used as bind parameters when compiling and executing the - expression. the DBAPI cursor object is returned.""" - e = self.engine - if len(multiparams): - bindparams = multiparams[0] - else: - bindparams = params - c = self.compile(e, parameters=bindparams) - return c.execute(*multiparams, **params) + return self.using(self.engine).execute(*multiparams, **params) def scalar(self, *multiparams, **params): - """executes this SQL expression via the execute() method, then - returns the first column of the first row. Useful for executing functions, - sequences, rowcounts, etc.""" - # we are still going off the assumption that fetching only the first row - # in a result set is not performance-wise any different than specifying limit=1 - # else we'd have to construct a copy of the select() object with the limit - # installed (else if we change the existing select, not threadsafe) - row = self.execute(*multiparams, **params).fetchone() - if row is not None: - return row[0] - else: - return None + return self.using(self.engine).scalar(*multiparams, **params) def __and__(self, other): return and_(self, other) -- 2.47.2