]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
current progress with exec branch
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Mar 2007 23:57:22 +0000 (23:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Mar 2007 23:57:22 +0000 (23:57 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/logging.py
lib/sqlalchemy/pool.py
test/testbase.py

index 43d570070f4d215c6b75cfad1c7dc703f4d1ed76..e34f4a5ccbc23bb66a66b69cd0dba1dee91c4172 100644 (file)
@@ -4,18 +4,13 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import datetime, sys, StringIO, string, types, re
+import datetime, string, types, re, random
 
-import sqlalchemy.util as util
-import sqlalchemy.sql as sql
-import sqlalchemy.engine as engine
+from sqlalchemy import util, sql, engine, schema, ansisql, exceptions
 import sqlalchemy.engine.default as default
-import sqlalchemy.schema as schema
-import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 from sqlalchemy.databases import information_schema as ischema
-import re
 
 try:
     import mx.DateTime.DateTime as mxDateTime
@@ -272,7 +267,9 @@ class PGDialect(ansisql.ANSIDialect):
         if self.server_side_cursors:
             # use server-side cursors:
             # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            return connection.cursor('x')
+            ident = "c" + hex(random.randint(0, 65535))[2:]
+            print "IDENT:", ident
+            return connection.cursor(ident)
         else:
             return connection.cursor()
 
index cf0d350358080c4670846c52a1cbcb280d7918aa..c154a1d680e1e71f0da02f7c669482e3a7b17168 100644 (file)
@@ -255,6 +255,9 @@ class Dialect(sql.AbstractDialect):
 class ExecutionContext(object):
     """A messenger object for a Dialect that corresponds to a single execution.
 
+    ExecutionContext should have a datamember "cursor" which is created
+    at initialization time.
+    
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
     methods will be called for compiled statements, afterwhich it is
@@ -263,7 +266,7 @@ class ExecutionContext(object):
     applicable.
     """
 
-    def pre_exec(self, engine, proxy, compiled, parameters):
+    def pre_exec(self):
         """Called before an execution of a compiled statement.
 
         `proxy` is a callable that takes a string statement and a bind
@@ -272,7 +275,7 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
-    def post_exec(self, engine, proxy, compiled, parameters):
+    def post_exec(self):
         """Called after the execution of a compiled statement.
 
         `proxy` is a callable that takes a string statement and a bind
@@ -281,7 +284,11 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
-    def get_rowcount(self, cursor):
+    def get_result_proxy(self):
+        """return a ResultProxy corresponding to this ExecutionContext."""
+        raise NotImplementedError()
+        
+    def get_rowcount(self):
         """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
 
         raise NotImplementedError()
@@ -497,68 +504,32 @@ class Connection(Connectable):
         """Execute a sql.Compiled object."""
         if not compiled.can_execute:
             raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
-        cursor = self.__engine.dialect.create_cursor(self.connection)
         parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)]
         if len(parameters) == 1:
             parameters = parameters[0]
-        def proxy(statement=None, parameters=None):
-            if statement is None:
-                return cursor
-
-            parameters = self.__engine.dialect.convert_compiled_params(parameters)
-            self._execute_raw(statement, parameters, cursor=cursor, context=context)
-            return cursor
-        context = self.__engine.dialect.create_execution_context()
-        context.pre_exec(self.__engine, proxy, compiled, parameters)
-        proxy(unicode(compiled), parameters)
-        context.post_exec(self.__engine, proxy, compiled, parameters)
-        rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
-        return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
-
-    # poor man's multimethod/generic function thingy
-    executors = {
-        sql._Function : execute_function,
-        sql.ClauseElement : execute_clauseelement,
-        sql.ClauseVisitor : execute_compiled,
-        schema.SchemaItem:execute_default,
-        str.__mro__[-2] : execute_text
-    }
-
-    def create(self, entity, **kwargs):
-        """Create a table or index given an appropriate schema object."""
-
-        return self.__engine.create(entity, connection=self, **kwargs)
-
-    def drop(self, entity, **kwargs):
-        """Drop a table or index given an appropriate schema object."""
-
-        return self.__engine.drop(entity, connection=self, **kwargs)
-
-    def reflecttable(self, table, **kwargs):
-        """Reflect the columns in the given table from the database."""
-
-        return self.__engine.reflecttable(table, connection=self, **kwargs)
-
-    def default_schema_name(self):
-        return self.__engine.dialect.get_default_schema_name(self)
-
-    def run_callable(self, callable_):
-        return callable_(self)
-
-    def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
-        if cursor is None:
-            cursor = self.__engine.dialect.create_cursor(self.connection)
+        context = self.__engine.dialect.create_execution_context(compiled=compiled, parameters=parameters, connection=self, engine=self.__engine)
+        context.pre_exec()
+        self.execute_compiled_impl(compiled, parameters, context)
+        context.post_exec()
+        return context.get_result_proxy()
+    
+    def _execute_compiled_impl(self, compiled, parameters, context):
+        self._execute_raw(unicode(compiled), self.__engine.dialect.convert_compiled_params(parameters), context=context)
+            
+    def _execute_raw(self, statement, parameters=None, context=None, **kwargs):
         if not self.__engine.dialect.supports_unicode_statements():
             # encode to ascii, with full error handling
             statement = statement.encode('ascii')
+        if context is None:
+            context = self.__engine.dialect.create_execution_context(statement=statement, parameters=parameters, connection=self, engine=self.__engine)
         self.__engine.logger.info(statement)
         self.__engine.logger.info(repr(parameters))
         if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
-            self._executemany(cursor, statement, parameters, context=context)
+            self._executemany(context.cursor, statement, parameters, context=context)
         else:
-            self._execute(cursor, statement, parameters, context=context)
+            self._execute(context.cursor, statement, parameters, context=context)
         self._autocommit(statement)
-        return cursor
+        return context.cursor
 
     def _execute(self, c, statement, parameters, context=None):
         if parameters is None:
@@ -585,6 +556,40 @@ class Connection(Connectable):
                 self.close()
             raise exceptions.SQLError(statement, parameters, e)
 
+
+
+
+    # poor man's multimethod/generic function thingy
+    executors = {
+        sql._Function : execute_function,
+        sql.ClauseElement : execute_clauseelement,
+        sql.ClauseVisitor : execute_compiled,
+        schema.SchemaItem:execute_default,
+        str.__mro__[-2] : execute_text
+    }
+
+    def create(self, entity, **kwargs):
+        """Create a table or index given an appropriate schema object."""
+
+        return self.__engine.create(entity, connection=self, **kwargs)
+
+    def drop(self, entity, **kwargs):
+        """Drop a table or index given an appropriate schema object."""
+
+        return self.__engine.drop(entity, connection=self, **kwargs)
+
+    def reflecttable(self, table, **kwargs):
+        """Reflect the columns in the given table from the database."""
+
+        return self.__engine.reflecttable(table, connection=self, **kwargs)
+
+    def default_schema_name(self):
+        return self.__engine.dialect.get_default_schema_name(self)
+
+    def run_callable(self, callable_):
+        return callable_(self)
+
+
     def proxy(self, statement=None, parameters=None):
         """Execute the given statement string and parameter object.
 
index 86563cd7cbcedfb5d4c6a876b9f57003bd5dba15..bcd7a6c36b3ad5cc36ba0b65b15eb3a013d2dae0 100644 (file)
@@ -157,15 +157,35 @@ class DefaultDialect(base.Dialect):
     ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
 
 class DefaultExecutionContext(base.ExecutionContext):
-    def __init__(self, dialect):
+    def __init__(self, dialect, engine, connection, compiled=None, parameters=None, statement=None):
         self.dialect = dialect
+        self.engine = engine
+        self.connection = connection
+        self.compiled = compiled
+        self.parameters = parameters
+        self.statement = statement
+        if compiled is not None:
+            self.typemap = compiled.typemap
+            self.column_labels = compiled.column_labels
+        else:
+            self.typemap = self.column_labels = None
+        self.cursor = self.dialect.create_cursor(self.connection.connection)
+
+    def proxy(self, statement=None, parameters=None):
+        if statement is not None:
+            self.connection._execute_compiled_impl(compiled, parameters, self)
+        return self.cursor
 
-    def pre_exec(self, engine, proxy, compiled, parameters):
-        self._process_defaults(engine, proxy, compiled, parameters)
+    def pre_exec(self):
+        if self.compiled is not None:
+            self._process_defaults()
 
-    def post_exec(self, engine, proxy, compiled, parameters):
+    def post_exec(self):
         pass
 
+    def get_result_proxy(self):
+        return base.ResultProxy(self.engine, self.connection, self.cursor, self, typemap=self.typemap, column_labels=self.column_labels)
+
     def get_rowcount(self, cursor):
         if hasattr(self, '_rowcount'):
             return self._rowcount
@@ -187,16 +207,16 @@ class DefaultExecutionContext(base.ExecutionContext):
     def lastrow_has_defaults(self):
         return self._lastrow_has_defaults
 
-    def set_input_sizes(self, cursor, parameters):
+    def set_input_sizes(self):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DBAPI types
         from the bind parameter's ``TypeEngine`` objects.
         """
 
-        if isinstance(parameters, list):
-            plist = parameters
+        if isinstance(self.parameters, list):
+            plist = self.parameters
         else:
-            plist = [parameters]
+            plist = [self.parameters]
         if self.dialect.positional:
             inputsizes = []
             for params in plist[0:1]:
@@ -205,7 +225,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
                     if dbtype is not None:
                         inputsizes.append(dbtype)
-            cursor.setinputsizes(*inputsizes)
+            self.cursor.setinputsizes(*inputsizes)
         else:
             inputsizes = {}
             for params in plist[0:1]:
@@ -214,9 +234,9 @@ class DefaultExecutionContext(base.ExecutionContext):
                     dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module)
                     if dbtype is not None:
                         inputsizes[key] = dbtype
-            cursor.setinputsizes(**inputsizes)
+            self.cursor.setinputsizes(**inputsizes)
 
-    def _process_defaults(self, engine, proxy, compiled, parameters):
+    def _process_defaults(self):
         """``INSERT`` and ``UPDATE`` statements, when compiled, may
         have additional columns added to their ``VALUES`` and ``SET``
         lists corresponding to column defaults/onupdates that are
@@ -234,23 +254,21 @@ class DefaultExecutionContext(base.ExecutionContext):
         statement.
         """
 
-        if compiled is None: return
-
-        if getattr(compiled, "isinsert", False):
-            if isinstance(parameters, list):
-                plist = parameters
+        if getattr(self.compiled, "isinsert", False):
+            if isinstance(self.parameters, list):
+                plist = self.parameters
             else:
-                plist = [parameters]
-            drunner = self.dialect.defaultrunner(engine, proxy)
+                plist = [self.parameters]
+            drunner = self.dialect.defaultrunner(self.engine, self.proxy)
             self._lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
                 need_lastrowid=False
                 # check the "default" status of each column in the table
-                for c in compiled.statement.table.c:
+                for c in self.compiled.statement.table.c:
                     # check if it will be populated by a SQL clause - we'll need that
                     # after execution.
-                    if c in compiled.inline_params:
+                    if c in self.compiled.inline_params:
                         self._lastrow_has_defaults = True
                         if c.primary_key:
                             need_lastrowid = True
@@ -278,19 +296,19 @@ class DefaultExecutionContext(base.ExecutionContext):
                 else:
                     self._last_inserted_ids = last_inserted_ids
                 self._last_inserted_params = param
-        elif getattr(compiled, 'isupdate', False):
-            if isinstance(parameters, list):
-                plist = parameters
+        elif getattr(self.compiled, 'isupdate', False):
+            if isinstance(self.parameters, list):
+                plist = self.parameters
             else:
-                plist = [parameters]
-            drunner = self.dialect.defaultrunner(engine, proxy)
+                plist = [self.parameters]
+            drunner = self.dialect.defaultrunner(self.engine, self.proxy)
             self._lastrow_has_defaults = False
             for param in plist:
                 # check the "onupdate" status of each column in the table
-                for c in compiled.statement.table.c:
+                for c in self.compiled.statement.table.c:
                     # it will be populated by a SQL clause - we'll need that
                     # after execution.
-                    if c in compiled.inline_params:
+                    if c in self.compiled.inline_params:
                         pass
                     # its not in the bind parameters, and theres an "onupdate" defined for the column;
                     # execute it and add to bind params
index 7a7b84aa99af49dfd22c5e8aca502f20fcec37dd..af860d557cea946dbb0ae5382cb6c1c14d4696b2 100644 (file)
@@ -73,6 +73,9 @@ class DefaultEngineStrategy(EngineStrategy):
 
             poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
             pool_args = {}
+
+            pool_args['cursor_creator'] = dialect.create_cursor
+            
             # consume pool arguments from kwargs, translating a few of the arguments
             for k in util.get_cls_kwargs(poolclass):
                 tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
index 6f4368707988ee41aab02ab92c45ecd34827bebd..91326233a6c57ed478cc2e0d34acaed88e806646 100644 (file)
@@ -31,8 +31,8 @@ import sys
 # py2.5 absolute imports will fix....
 logging = __import__('logging')
 
-# turn off logging at the root sqlalchemy level
-logging.getLogger('sqlalchemy').setLevel(logging.ERROR)
+
+logging.getLogger('sqlalchemy').setLevel(logging.WARN)
 
 default_enabled = False
 def default_logging(name):
index 787fd059f288bb31cd2bcdaf678705a18939c70d..d65b28b5576da218f7c86dd2b5eadf705c2801d2 100644 (file)
@@ -237,7 +237,9 @@ class _ConnectionFairy(object):
             raise
         if self.__pool.echo:
             self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
-
+    
+    _logger = property(lambda self: self.__pool.logger)
+         
     def invalidate(self):
         if self.connection is None:
             raise exceptions.InvalidRequestError("This connection is closed")
@@ -311,7 +313,10 @@ class _CursorFairy(object):
     def close(self):
         if self in self.__parent._cursors:
             del self.__parent._cursors[self]
-            self.cursor.close()
+            try:
+                self.cursor.close()
+            except Exception, e:
+                self.__parent._logger.warn("Error closing cursor: " + str(e))
 
     def __getattr__(self, key):
         return getattr(self.cursor, key)
index 8a1d9ee59a827b7ebbf04ffaf4b26265f2376c2f..c02b36b5284d14749d5646dc5c3d6f94a960fdc0 100644 (file)
@@ -49,6 +49,7 @@ def parse_argv():
     parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
     parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
     parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
+    parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
     
     (options, args) = parser.parse_args()
     sys.argv[1:] = args
@@ -73,7 +74,7 @@ def parse_argv():
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
         elif DBTYPE == 'oracle8':
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
-            opts = {'use_ansi':False}
+            opts['use_ansi'] = False
         elif DBTYPE == 'mssql':
             db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
         elif DBTYPE == 'firebird':
@@ -94,6 +95,9 @@ def parse_argv():
     
     global with_coverage
     with_coverage = options.coverage
+
+    if options.serverside:
+        opts['server_side_cursors'] = True
     
     if options.enginestrategy is not None:
         opts['strategy'] = options.enginestrategy