]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
split up Session into Session/LegacySession, added some new constructor args
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 23:46:02 +0000 (23:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 23:46:02 +0000 (23:46 +0000)
created AbstractEngine class which provides base for SQLEngine and will also
provide base for ConnectionProxy, so SQL binding can be to an engine or specific
connection resource
ClauseElements get using() method which can take AbstractEngines for execution
made more separation between SchemaItems and bound engine

lib/sqlalchemy/ext/proxy.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/mapping/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py

index 38325bea35e38de267bd7185c8cedd5c54269c05..a24f089e9cf70d7a327d3a6cbdd86d4f09c0c667 100644 (file)
@@ -24,7 +24,15 @@ class BaseProxyEngine(schema.SchemaEngine):
 
     def reflecttable(self, table):
         return self.get_engine().reflecttable(table)
-        
+    def execute_compiled(self, *args, **kwargs):
+        return self.get_engine().execute_compiled(*args, **kwargs)
+    def compiler(self, *args, **kwargs):
+        return self.get_engine().compiler(*args, **kwargs)
+    def schemagenerator(self, *args, **kwargs):
+        return self.get_engine().schemagenerator(*args, **kwargs)
+    def schemadropper(self, *args, **kwargs):
+        return self.get_engine().schemadropper(*args, **kwargs)
+            
     def hash_key(self):
         return "%s(%s)" % (self.__class__.__name__, id(self))
 
index 1491d39ac0e0e750d050bee41308236315169632..faf5ddbd6b1796fcf78784aee2c7a1241009a093 100644 (file)
@@ -17,7 +17,7 @@ import sqlalchemy
 class Session(object):
     """Maintains a UnitOfWork instance, including transaction state."""
     
-    def __init__(self, nest_on=None, hash_key=None):
+    def __init__(self, hash_key=None, new_imap=True, import_session=None):
         """Initialize the objectstore with a UnitOfWork registry.  If called
         with no arguments, creates a single UnitOfWork for all operations.
         
@@ -26,31 +26,23 @@ class Session(object):
         hash_key - the hash_key used to identify objects against this session, which 
         defaults to the id of the Session instance.
         """
-        self.uow = unitofwork.UnitOfWork()
-        self.parent_uow = None
-        self.begin_count = 0
-        self.nest_on = util.to_list(nest_on)
-        self.__pushed_count = 0
+        if import_session is not None:
+            self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map)
+        elif new_imap is False:
+            self.uow = unitofwork.UnitOfWork(identity_map=objectstore.get_session().uow.identity_map)
+        else:
+            self.uow = unitofwork.UnitOfWork()
+            
+        self.binds = {}
         if hash_key is None:
             self.hash_key = id(self)
         else:
             self.hash_key = hash_key
         _sessions[self.hash_key] = self
     
-    def was_pushed(self):
-        if self.nest_on is None:
-            return
-        self.__pushed_count += 1
-        if self.__pushed_count == 1:
-            for n in self.nest_on:
-                n.push_session()
-    def was_popped(self):
-        if self.nest_on is None or self.__pushed_count == 0:
-            return
-        self.__pushed_count -= 1
-        if self.__pushed_count == 0:
-            for n in self.nest_on:
-                n.pop_session()
+    def bind_table(self, table, bindto):
+        self.binds[table] = bindto
+                
     def get_id_key(ident, class_, entity_name=None):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a tuple of the object's primary key values.
@@ -81,79 +73,12 @@ class Session(object):
         """
         return (class_, tuple([row[column] for column in primary_key]), entity_name)
     get_row_key = staticmethod(get_row_key)
-
-    class SessionTrans(object):
-        """returned by Session.begin(), denotes a transactionalized UnitOfWork instance.
-        call commit() on this to commit the transaction."""
-        def __init__(self, parent, uow, isactive):
-            self.__parent = parent
-            self.__isactive = isactive
-            self.__uow = uow
-        isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
-        parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.")
-        uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.")
-        def begin(self):
-            """calls begin() on the underlying Session object, returning a new no-op SessionTrans object."""
-            if self.parent.uow is not self.uow:
-                raise InvalidRequestError("This SessionTrans is no longer valid")
-            return self.parent.begin()
-        def commit(self):
-            """commits the transaction noted by this SessionTrans object."""
-            self.__parent._trans_commit(self)
-            self.__isactive = False
-        def rollback(self):
-            """rolls back the current UnitOfWork transaction, in the case that begin()
-            has been called.  The changes logged since the begin() call are discarded."""
-            self.__parent._trans_rollback(self)
-            self.__isactive = False
-
-    def begin(self):
-        """begins a new UnitOfWork transaction and returns a tranasaction-holding
-        object.  commit() or rollback() should be called on the returned object.
-        commit() on the Session will do nothing while a transaction is pending, and further
-        calls to begin() will return no-op transactional objects."""
-        if self.parent_uow is not None:
-            return Session.SessionTrans(self, self.uow, False)
-        self.parent_uow = self.uow
-        self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
-        return Session.SessionTrans(self, self.uow, True)
     
     def engines(self, mapper):
         return [t.engine for t in mapper.tables]
         
-    def _trans_commit(self, trans):
-        if trans.uow is self.uow and trans.isactive:
-            try:
-                self._commit_uow()
-            finally:
-                self.uow = self.parent_uow
-                self.parent_uow = None
-    def _trans_rollback(self, trans):
-        if trans.uow is self.uow:
-            self.uow = self.parent_uow
-            self.parent_uow = None
-
-    def _commit_uow(self, *obj):
-        self.was_pushed()
-        try:
-            self.uow.flush(self, *obj)
-        finally:
-            self.was_popped()
-                        
-    def commit(self, *objects):
-        """commits the current UnitOfWork transaction.  called with
-        no arguments, this is only used
-        for "implicit" transactions when there was no begin().
-        if individual objects are submitted, then only those objects are committed, and the 
-        begin/commit cycle is not affected."""
-        # if an object list is given, commit just those but dont
-        # change begin/commit status
-        if len(objects):
-            self._commit_uow(*objects)
-            self.uow.flush(self, *objects)
-            return
-        if self.parent_uow is None:
-            self._commit_uow()
+    def flush(self, *obj):
+        self.uow.flush(self, *obj)
             
     def refresh(self, *obj):
         """reloads the attributes for the given objects from the database, clears
@@ -221,6 +146,95 @@ class Session(object):
             u.register_new(instance)
         return instance
 
+class LegacySession(Session):
+    def __init__(self, nest_on=None, hash_key=None, **kwargs):
+        super(LegacySession, self).__init__(**kwargs)
+        self.parent_uow = None
+        self.begin_count = 0
+        self.nest_on = util.to_list(nest_on)
+        self.__pushed_count = 0
+    def was_pushed(self):
+        if self.nest_on is None:
+            return
+        self.__pushed_count += 1
+        if self.__pushed_count == 1:
+            for n in self.nest_on:
+                n.push_session()
+    def was_popped(self):
+        if self.nest_on is None or self.__pushed_count == 0:
+            return
+        self.__pushed_count -= 1
+        if self.__pushed_count == 0:
+            for n in self.nest_on:
+                n.pop_session()
+    class SessionTrans(object):
+        """returned by Session.begin(), denotes a transactionalized UnitOfWork instance.
+        call commit() on this to commit the transaction."""
+        def __init__(self, parent, uow, isactive):
+            self.__parent = parent
+            self.__isactive = isactive
+            self.__uow = uow
+        isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
+        parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.")
+        uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.")
+        def begin(self):
+            """calls begin() on the underlying Session object, returning a new no-op SessionTrans object."""
+            if self.parent.uow is not self.uow:
+                raise InvalidRequestError("This SessionTrans is no longer valid")
+            return self.parent.begin()
+        def commit(self):
+            """commits the transaction noted by this SessionTrans object."""
+            self.__parent._trans_commit(self)
+            self.__isactive = False
+        def rollback(self):
+            """rolls back the current UnitOfWork transaction, in the case that begin()
+            has been called.  The changes logged since the begin() call are discarded."""
+            self.__parent._trans_rollback(self)
+            self.__isactive = False
+    def begin(self):
+        """begins a new UnitOfWork transaction and returns a tranasaction-holding
+        object.  commit() or rollback() should be called on the returned object.
+        commit() on the Session will do nothing while a transaction is pending, and further
+        calls to begin() will return no-op transactional objects."""
+        if self.parent_uow is not None:
+            return Session.SessionTrans(self, self.uow, False)
+        self.parent_uow = self.uow
+        self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
+        return Session.SessionTrans(self, self.uow, True)
+    def commit(self, *objects):
+        """commits the current UnitOfWork transaction.  called with
+        no arguments, this is only used
+        for "implicit" transactions when there was no begin().
+        if individual objects are submitted, then only those objects are committed, and the 
+        begin/commit cycle is not affected."""
+        # if an object list is given, commit just those but dont
+        # change begin/commit status
+        if len(objects):
+            self._commit_uow(*objects)
+            self.uow.flush(self, *objects)
+            return
+        if self.parent_uow is None:
+            self._commit_uow()
+    def _trans_commit(self, trans):
+        if trans.uow is self.uow and trans.isactive:
+            try:
+                self._commit_uow()
+            finally:
+                self.uow = self.parent_uow
+                self.parent_uow = None
+    def _trans_rollback(self, trans):
+        if trans.uow is self.uow:
+            self.uow = self.parent_uow
+            self.parent_uow = None
+    def _commit_uow(self, *obj):
+        self.was_pushed()
+        try:
+            self.uow.flush(self, *obj)
+        finally:
+            self.was_popped()
+
+Session = LegacySession
+
 def get_id_key(ident, class_, entity_name=None):
     return Session.get_id_key(ident, class_, entity_name)
 
@@ -228,19 +242,22 @@ def get_row_key(row, class_, primary_key, entity_name=None):
     return Session.get_row_key(row, class_, primary_key, entity_name)
 
 def begin():
-    """begins a new UnitOfWork transaction.  the next commit will affect only
-    objects that are created, modified, or deleted following the begin statement."""
+    """deprecated.  use s = Session(new_imap=False)."""
     return get_session().begin()
 
 def commit(*obj):
-    """commits the current UnitOfWork transaction.  if a transaction was begun 
-    via begin(), commits only those objects that were created, modified, or deleted
-    since that begin statement.  otherwise commits all objects that have been
+    """deprecated; use flush(*obj)"""
+    get_session().flush(*obj)
+
+def flush(*obj):
+    """flushes the current UnitOfWork transaction.  if a transaction was begun 
+    via begin(), flushes only those objects that were created, modified, or deleted
+    since that begin statement.  otherwise flushes all objects that have been
     changed.
-    
+
     if individual objects are submitted, then only those objects are committed, and the 
     begin/commit cycle is not affected."""
-    get_session().commit(*obj)
+    get_session().flush(*obj)
 
 def clear():
     """removes all current UnitOfWorks and IdentityMaps for this thread and 
index 09c2b9b6ec59bd93fbeec1ddfe5aa6c5a58d7b65..950c2be42c034da3b014f909758503f24f9a49cc 100644 (file)
@@ -10,6 +10,7 @@ class Query(object):
         self.mapper = mapper
         self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
         self.order_by = kwargs.pop('order_by', self.mapper.order_by)
+        self.extension = kwargs.pop('extension', self.mapper.extension)
         self._session = kwargs.pop('session', None)
         if not hasattr(mapper, '_get_clause'):
             _get_clause = sql.and_()
@@ -66,7 +67,7 @@ class Query(object):
 
         e.g.   result = usermapper.select_by(user_name = 'fred')
         """
-        ret = self.mapper.extension.select_by(self, *args, **params)
+        ret = self.extension.select_by(self, *args, **params)
         if ret is not mapper.EXT_PASS:
             return ret
         return self.select_whereclause(self._by_clause(*args, **params))
@@ -116,7 +117,7 @@ class Query(object):
         in this case, the developer must insure that an adequate set of columns exists in the 
         rowset with which to build new object instances."""
 
-        ret = self.mapper.extension.select(self, arg=arg, **kwargs)
+        ret = self.extension.select(self, arg=arg, **kwargs)
         if ret is not mapper.EXT_PASS:
             return ret
         elif arg is not None and isinstance(arg, sql.Selectable):
index 24392b3d973932376f993179896987c3f44427e9..acce555ab27ea8c4282c4111c2226c4b174b2613 100644 (file)
@@ -23,8 +23,17 @@ import copy, re, string
 __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
            'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
+class SchemaMeta(type):
+    """provides universal constructor arguments for all SchemaItems"""
+    def __call__(self, *args, **kwargs):
+        engine = kwargs.pop('engine', None)
+        obj = type.__call__(self, *args, **kwargs)
+        obj._engine = engine
+        return obj
+    
 class SchemaItem(object):
     """base class for items that define a database schema."""
+    __metaclass__ = SchemaMeta
     def _init_items(self, *args):
         for item in args:
             if item is not None:
@@ -34,7 +43,20 @@ class SchemaItem(object):
         raise NotImplementedError()
     def __repr__(self):
         return "%s()" % self.__class__.__name__
-
+        
+class EngineMixin(object):
+    """a mixin for SchemaItems that provides an "engine" accessor."""
+    def _derived_engine(self):
+        """subclasses override this method to return an AbstractEngine
+        bound to a parent item"""
+        return None
+    def _get_engine(self):
+        if self._engine is not None:
+            return self._engine
+        else:
+            return self._derived_engine()
+    engine = property(_get_engine)
+    
 def _get_table_key(engine, name, schema):
     if schema is not None and schema == engine.get_default_schema_name():
         schema = None
@@ -43,14 +65,12 @@ def _get_table_key(engine, name, schema):
     else:
         return schema + "." + name
         
-class TableSingleton(type):
+class TableSingleton(SchemaMeta):
     """a metaclass used by the Table object to provide singleton behavior."""
     def __call__(self, name, engine=None, *args, **kwargs):
         try:
-            if not isinstance(engine, SchemaEngine):
+            if engine is not None and not isinstance(engine, SchemaEngine):
                 args = [engine] + list(args)
-                engine = None
-            if engine is None:
                 engine = default_engine
             name = str(name)    # in case of incoming unicode
             schema = kwargs.get('schema', None)
@@ -58,6 +78,10 @@ class TableSingleton(type):
             redefine = kwargs.pop('redefine', False)
             mustexist = kwargs.pop('mustexist', False)
             useexisting = kwargs.pop('useexisting', False)
+            if not engine:
+                table = type.__call__(self, name, engine, **kwargs)
+                table._init_items(*args)
+                return table
             key = _get_table_key(engine, name, schema)
             table = engine.tables[key]
             if len(args):
@@ -440,15 +464,14 @@ class ForeignKey(SchemaItem):
         self.parent.foreign_key = self
         self.parent.table.foreign_keys.append(self)
 
-class DefaultGenerator(SchemaItem):
+class DefaultGenerator(SchemaItem, EngineMixin):
     """Base class for column "default" values."""
-    def __init__(self, for_update=False, engine=None):
+    def __init__(self, for_update=False):
         self.for_update = for_update
-        self.engine = engine
+    def _derived_engine(self):
+        return self.column.table.engine
     def _set_parent(self, column):
         self.column = column
-        if self.engine is None:
-            self.engine = column.table.engine
         if self.for_update:
             self.column.onupdate = self
         else:
@@ -509,7 +532,7 @@ class Sequence(DefaultGenerator):
         return visitor.visit_sequence(self)
 
 
-class Index(SchemaItem):
+class Index(SchemaItem, EngineMixin):
     """Represents an index of columns from a database table
     """
     def __init__(self, name, *columns, **kw):
@@ -530,7 +553,8 @@ class Index(SchemaItem):
         self.unique = kw.pop('unique', False)
         self._init_items(*columns)
 
-    engine = property(lambda s:s.table.engine)
+    def _derived_engine(self):
+        return self.table.engine
     def _init_items(self, *args):
         for column in args:
             self.append_column(column)
@@ -570,18 +594,21 @@ class Index(SchemaItem):
                                                  for c in self.columns]),
                                       (self.unique and ', unique=True') or '')
         
-class SchemaEngine(object):
+class SchemaEngine(sql.AbstractEngine):
     """a factory object used to create implementations for schema objects.  This object
     is the ultimate base class for the engine.SQLEngine class."""
 
     def __init__(self):
         # a dictionary that stores Table objects keyed off their name (and possibly schema name)
         self.tables = {}
-        
     def reflecttable(self, table):
         """given a table, will query the database and populate its Column and ForeignKey 
         objects."""
         raise NotImplementedError()
+    def schemagenerator(self, **params):
+        raise NotImplementedError()
+    def schemadropper(self, **params):
+        raise NotImplementedError()
         
 class SchemaVisitor(sql.ClauseVisitor):
     """defines the visiting for SchemaItem objects"""
index f6e2d03c9ac0bb028097eb89eb0602d02b1b416e..2bc025e9f06e7b70d6cf5da6f8ff01d54436e02d 100644 (file)
@@ -246,6 +246,12 @@ def _is_literal(element):
 def is_column(col):
     return isinstance(col, ColumnElement)
 
+class AbstractEngine(object):
+    def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
+        raise NotImplementedError()
+    def compiler(self, statement, parameters, **kwargs):
+        raise NotImplementedError()
+
 class ClauseParameters(util.OrderedDict):
     """represents a dictionary/iterator of bind parameter key names/values.  Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method.  Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects.  The non-converted value can be retrieved via the get_original method.  For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order."""
     def __init__(self, engine=None):
@@ -340,8 +346,11 @@ class Compiled(ClauseVisitor):
         """executes this compiled object using the underlying SQLEngine"""
         if len(multiparams):
             params = multiparams
-            
-        return self.engine.execute_compiled(self, params)
+        
+        e = self.engine
+        if e is None:
+            raise InvalidRequestError("This Compiled object is not bound to any engine.")
+        return e.execute_compiled(self, params)
 
     def scalar(self, *multiparams, **params):
         """executes this compiled object via the execute() method, then 
@@ -356,7 +365,26 @@ class Compiled(ClauseVisitor):
             return row[0]
         else:
             return None
-        
+
+class Executor(object):
+    """handles the compilation/execution of a ClauseElement within the context of a particular AbtractEngine.  This 
+    AbstractEngine will usually be a SQLEngine or ConnectionProxy."""
+    def __init__(self, clauseelement, abstractengine=None):
+        self.engine=abstractengine
+        self.clauseelement = clauseelement
+    def execute(self, *multiparams, **params):
+        return self.compile(*multiparams, **params).execute(*multiparams, **params)
+    def scalar(self, *multiparams, **params):
+        return self.compile(*multiparams, **params).scalar(*multiparams, **params)
+    def compile(self, *multiparams, **params):
+        if len(multiparams):
+            bindparams = multiparams[0]
+        else:
+            bindparams = params
+        compiler = self.engine.compiler(self.clauseelement, bindparams)
+        compiler.compile()
+        return compiler
+            
 class ClauseElement(object):
     """base class for elements of a programmatically constructed SQL expression."""
     def _get_from_objects(self):
@@ -415,10 +443,12 @@ class ClauseElement(object):
             
     engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.")
 
-
+    def using(self, abstractengine):
+        return Executor(self, abstractengine)
+        
     def compile(self, engine = None, parameters = None, typemap=None, compiler=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
-        a Compiled object.  If no engine can be found, an ansisql engine is used.
+        a Compiled object.  If no engine can be found, an ANSICompiler is used with no engine.
         bindparams is a dictionary representing the default bind parameters to be used with 
         the statement.  """
         
@@ -430,7 +460,7 @@ class ClauseElement(object):
                 
         if compiler is None:
             import sqlalchemy.ansisql as ansisql
-            compiler = ansisql.ANSICompiler(self, parameters=parameters, typemap=typemap)
+            compiler = ansisql.ANSICompiler(self, parameters=parameters)
         compiler.compile()
         return compiler
 
@@ -438,30 +468,10 @@ class ClauseElement(object):
         return str(self.compile())
         
     def execute(self, *multiparams, **params):
-        """compiles and executes this SQL expression using its underlying SQLEngine. the
-        given **params are used as bind parameters when compiling and executing the
-        expression. the DBAPI cursor object is returned."""
-        e = self.engine
-        if len(multiparams):
-            bindparams = multiparams[0]
-        else:
-            bindparams = params
-        c = self.compile(e, parameters=bindparams)
-        return c.execute(*multiparams, **params)
+        return self.using(self.engine).execute(*multiparams, **params)
 
     def scalar(self, *multiparams, **params):
-        """executes this SQL expression via the execute() method, then 
-        returns the first column of the first row.  Useful for executing functions,
-        sequences, rowcounts, etc."""
-        # we are still going off the assumption that fetching only the first row
-        # in a result set is not performance-wise any different than specifying limit=1
-        # else we'd have to construct a copy of the select() object with the limit
-        # installed (else if we change the existing select, not threadsafe)
-        row = self.execute(*multiparams, **params).fetchone()
-        if row is not None:
-            return row[0]
-        else:
-            return None
+        return self.using(self.engine).scalar(*multiparams, **params)
 
     def __and__(self, other):
         return and_(self, other)