From 6b7bd8fb1575ce47d221e8c9a9cc579abc271c92 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 29 Mar 2007 23:57:22 +0000 Subject: [PATCH] current progress with exec branch --- lib/sqlalchemy/databases/postgres.py | 13 ++- lib/sqlalchemy/engine/base.py | 113 ++++++++++++++------------- lib/sqlalchemy/engine/default.py | 72 ++++++++++------- lib/sqlalchemy/engine/strategies.py | 3 + lib/sqlalchemy/logging.py | 4 +- lib/sqlalchemy/pool.py | 9 ++- test/testbase.py | 6 +- 7 files changed, 126 insertions(+), 94 deletions(-) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 43d570070f..e34f4a5ccb 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,18 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, sys, StringIO, string, types, re +import datetime, string, types, re, random -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine +from sqlalchemy import util, sql, engine, schema, ansisql, exceptions import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions from sqlalchemy.databases import information_schema as ischema -import re try: import mx.DateTime.DateTime as mxDateTime @@ -272,7 +267,9 @@ class PGDialect(ansisql.ANSIDialect): if self.server_side_cursors: # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - return connection.cursor('x') + ident = "c" + hex(random.randint(0, 65535))[2:] + print "IDENT:", ident + return connection.cursor(ident) else: return connection.cursor() diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index cf0d350358..c154a1d680 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -255,6 +255,9 @@ class Dialect(sql.AbstractDialect): class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. + ExecutionContext should have a datamember "cursor" which is created + at initialization time. + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` methods will be called for compiled statements, afterwhich it is @@ -263,7 +266,7 @@ class ExecutionContext(object): applicable. """ - def pre_exec(self, engine, proxy, compiled, parameters): + def pre_exec(self): """Called before an execution of a compiled statement. `proxy` is a callable that takes a string statement and a bind @@ -272,7 +275,7 @@ class ExecutionContext(object): raise NotImplementedError() - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): """Called after the execution of a compiled statement. `proxy` is a callable that takes a string statement and a bind @@ -281,7 +284,11 @@ class ExecutionContext(object): raise NotImplementedError() - def get_rowcount(self, cursor): + def get_result_proxy(self): + """return a ResultProxy corresponding to this ExecutionContext.""" + raise NotImplementedError() + + def get_rowcount(self): """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" raise NotImplementedError() @@ -497,68 +504,32 @@ class Connection(Connectable): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - cursor = self.__engine.dialect.create_cursor(self.connection) parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] if len(parameters) == 1: parameters = parameters[0] - def proxy(statement=None, parameters=None): - if statement is None: - return cursor - - parameters = self.__engine.dialect.convert_compiled_params(parameters) - self._execute_raw(statement, parameters, cursor=cursor, context=context) - return cursor - context = self.__engine.dialect.create_execution_context() - context.pre_exec(self.__engine, proxy, compiled, parameters) - proxy(unicode(compiled), parameters) - context.post_exec(self.__engine, proxy, compiled, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs) - - # poor man's multimethod/generic function thingy - executors = { - sql._Function : execute_function, - sql.ClauseElement : execute_clauseelement, - sql.ClauseVisitor : execute_compiled, - schema.SchemaItem:execute_default, - str.__mro__[-2] : execute_text - } - - def create(self, entity, **kwargs): - """Create a table or index given an appropriate schema object.""" - - return self.__engine.create(entity, connection=self, **kwargs) - - def drop(self, entity, **kwargs): - """Drop a table or index given an appropriate schema object.""" - - return self.__engine.drop(entity, connection=self, **kwargs) - - def reflecttable(self, table, **kwargs): - """Reflect the columns in the given table from the database.""" - - return self.__engine.reflecttable(table, connection=self, **kwargs) - - def default_schema_name(self): - return self.__engine.dialect.get_default_schema_name(self) - - def run_callable(self, callable_): - return callable_(self) - - def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs): - if cursor is None: - cursor = self.__engine.dialect.create_cursor(self.connection) + context = self.__engine.dialect.create_execution_context(compiled=compiled, parameters=parameters, connection=self, engine=self.__engine) + context.pre_exec() + self.execute_compiled_impl(compiled, parameters, context) + context.post_exec() + return context.get_result_proxy() + + def _execute_compiled_impl(self, compiled, parameters, context): + self._execute_raw(unicode(compiled), self.__engine.dialect.convert_compiled_params(parameters), context=context) + + def _execute_raw(self, statement, parameters=None, context=None, **kwargs): if not self.__engine.dialect.supports_unicode_statements(): # encode to ascii, with full error handling statement = statement.encode('ascii') + if context is None: + context = self.__engine.dialect.create_execution_context(statement=statement, parameters=parameters, connection=self, engine=self.__engine) self.__engine.logger.info(statement) self.__engine.logger.info(repr(parameters)) if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): - self._executemany(cursor, statement, parameters, context=context) + self._executemany(context.cursor, statement, parameters, context=context) else: - self._execute(cursor, statement, parameters, context=context) + self._execute(context.cursor, statement, parameters, context=context) self._autocommit(statement) - return cursor + return context.cursor def _execute(self, c, statement, parameters, context=None): if parameters is None: @@ -585,6 +556,40 @@ class Connection(Connectable): self.close() raise exceptions.SQLError(statement, parameters, e) + + + + # poor man's multimethod/generic function thingy + executors = { + sql._Function : execute_function, + sql.ClauseElement : execute_clauseelement, + sql.ClauseVisitor : execute_compiled, + schema.SchemaItem:execute_default, + str.__mro__[-2] : execute_text + } + + def create(self, entity, **kwargs): + """Create a table or index given an appropriate schema object.""" + + return self.__engine.create(entity, connection=self, **kwargs) + + def drop(self, entity, **kwargs): + """Drop a table or index given an appropriate schema object.""" + + return self.__engine.drop(entity, connection=self, **kwargs) + + def reflecttable(self, table, **kwargs): + """Reflect the columns in the given table from the database.""" + + return self.__engine.reflecttable(table, connection=self, **kwargs) + + def default_schema_name(self): + return self.__engine.dialect.get_default_schema_name(self) + + def run_callable(self, callable_): + return callable_(self) + + def proxy(self, statement=None, parameters=None): """Execute the given statement string and parameter object. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 86563cd7cb..bcd7a6c36b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -157,15 +157,35 @@ class DefaultDialect(base.Dialect): ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect): + def __init__(self, dialect, engine, connection, compiled=None, parameters=None, statement=None): self.dialect = dialect + self.engine = engine + self.connection = connection + self.compiled = compiled + self.parameters = parameters + self.statement = statement + if compiled is not None: + self.typemap = compiled.typemap + self.column_labels = compiled.column_labels + else: + self.typemap = self.column_labels = None + self.cursor = self.dialect.create_cursor(self.connection.connection) + + def proxy(self, statement=None, parameters=None): + if statement is not None: + self.connection._execute_compiled_impl(compiled, parameters, self) + return self.cursor - def pre_exec(self, engine, proxy, compiled, parameters): - self._process_defaults(engine, proxy, compiled, parameters) + def pre_exec(self): + if self.compiled is not None: + self._process_defaults() - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): pass + def get_result_proxy(self): + return base.ResultProxy(self.engine, self.connection, self.cursor, self, typemap=self.typemap, column_labels=self.column_labels) + def get_rowcount(self, cursor): if hasattr(self, '_rowcount'): return self._rowcount @@ -187,16 +207,16 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return self._lastrow_has_defaults - def set_input_sizes(self, cursor, parameters): + def set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DBAPI types from the bind parameter's ``TypeEngine`` objects. """ - if isinstance(parameters, list): - plist = parameters + if isinstance(self.parameters, list): + plist = self.parameters else: - plist = [parameters] + plist = [self.parameters] if self.dialect.positional: inputsizes = [] for params in plist[0:1]: @@ -205,7 +225,7 @@ class DefaultExecutionContext(base.ExecutionContext): dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) if dbtype is not None: inputsizes.append(dbtype) - cursor.setinputsizes(*inputsizes) + self.cursor.setinputsizes(*inputsizes) else: inputsizes = {} for params in plist[0:1]: @@ -214,9 +234,9 @@ class DefaultExecutionContext(base.ExecutionContext): dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) if dbtype is not None: inputsizes[key] = dbtype - cursor.setinputsizes(**inputsizes) + self.cursor.setinputsizes(**inputsizes) - def _process_defaults(self, engine, proxy, compiled, parameters): + def _process_defaults(self): """``INSERT`` and ``UPDATE`` statements, when compiled, may have additional columns added to their ``VALUES`` and ``SET`` lists corresponding to column defaults/onupdates that are @@ -234,23 +254,21 @@ class DefaultExecutionContext(base.ExecutionContext): statement. """ - if compiled is None: return - - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters + if getattr(self.compiled, "isinsert", False): + if isinstance(self.parameters, list): + plist = self.parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.parameters] + drunner = self.dialect.defaultrunner(self.engine, self.proxy) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] need_lastrowid=False # check the "default" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # check if it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: self._lastrow_has_defaults = True if c.primary_key: need_lastrowid = True @@ -278,19 +296,19 @@ class DefaultExecutionContext(base.ExecutionContext): else: self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param - elif getattr(compiled, 'isupdate', False): - if isinstance(parameters, list): - plist = parameters + elif getattr(self.compiled, 'isupdate', False): + if isinstance(self.parameters, list): + plist = self.parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.parameters] + drunner = self.dialect.defaultrunner(self.engine, self.proxy) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: pass # its not in the bind parameters, and theres an "onupdate" defined for the column; # execute it and add to bind params diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 7a7b84aa99..af860d557c 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -73,6 +73,9 @@ class DefaultEngineStrategy(EngineStrategy): poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool)) pool_args = {} + + pool_args['cursor_creator'] = dialect.create_cursor + # consume pool arguments from kwargs, translating a few of the arguments for k in util.get_cls_kwargs(poolclass): tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k) diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/logging.py index 6f43687079..91326233a6 100644 --- a/lib/sqlalchemy/logging.py +++ b/lib/sqlalchemy/logging.py @@ -31,8 +31,8 @@ import sys # py2.5 absolute imports will fix.... logging = __import__('logging') -# turn off logging at the root sqlalchemy level -logging.getLogger('sqlalchemy').setLevel(logging.ERROR) + +logging.getLogger('sqlalchemy').setLevel(logging.WARN) default_enabled = False def default_logging(name): diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 787fd059f2..d65b28b557 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -237,7 +237,9 @@ class _ConnectionFairy(object): raise if self.__pool.echo: self.__pool.log("Connection %s checked out from pool" % repr(self.connection)) - + + _logger = property(lambda self: self.__pool.logger) + def invalidate(self): if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") @@ -311,7 +313,10 @@ class _CursorFairy(object): def close(self): if self in self.__parent._cursors: del self.__parent._cursors[self] - self.cursor.close() + try: + self.cursor.close() + except Exception, e: + self.__parent._logger.warn("Error closing cursor: " + str(e)) def __getattr__(self, key): return getattr(self.cursor, key) diff --git a/test/testbase.py b/test/testbase.py index 8a1d9ee59a..c02b36b528 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -49,6 +49,7 @@ def parse_argv(): parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)") parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running") parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)") + parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG") (options, args) = parser.parse_args() sys.argv[1:] = args @@ -73,7 +74,7 @@ def parse_argv(): db_uri = 'oracle://scott:tiger@127.0.0.1:1521' elif DBTYPE == 'oracle8': db_uri = 'oracle://scott:tiger@127.0.0.1:1521' - opts = {'use_ansi':False} + opts['use_ansi'] = False elif DBTYPE == 'mssql': db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test' elif DBTYPE == 'firebird': @@ -94,6 +95,9 @@ def parse_argv(): global with_coverage with_coverage = options.coverage + + if options.serverside: + opts['server_side_cursors'] = True if options.enginestrategy is not None: opts['strategy'] = options.enginestrategy -- 2.47.2