class PGDefaultRunner(base.DefaultRunner):
def get_column_default(self, column, isinsert=True):
if column.primary_key:
- # passive defaults on primary keys have to be overridden
+ # pre-execute passive defaults on primary keys
if isinstance(column.default, schema.PassiveDefault):
- return self.connection.execute("select %s" % column.default.arg).scalar()
+ return self.execute_string("select %s" % column.default.arg)
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
# TODO: this has to build into the Sequence object so we can get the quoting
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
else:
exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.connection.execute(exc).scalar()
+ return self.execute_string(exc.encode(self.dialect.encoding))
return super(PGDefaultRunner, self).get_column_default(column)
def visit_sequence(self, seq):
if not seq.optional:
- return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar()
+ return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).encode(self.dialect.encoding))
else:
return None
return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
def __execute_raw(self, context):
- if self.__engine._should_log_info:
- self.__engine.logger.info(context.statement)
- self.__engine.logger.info(repr(context.parameters))
if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
- self.__executemany(context)
+ self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
else:
- self.__execute(context)
+ if context.parameters is None:
+ if context.dialect.positional:
+ parameters = ()
+ else:
+ parameters = {}
+ else:
+ parameters = context.parameters
+ self._cursor_execute(context.cursor, context.statement, parameters, context=context)
self._autocommit(context)
- def __execute(self, context):
- if context.parameters is None:
- if context.dialect.positional:
- context.parameters = ()
- else:
- context.parameters = {}
+ def _cursor_execute(self, cursor, statement, parameters, context=None):
+ if self.__engine._should_log_info:
+ self.__engine.logger.info(statement)
+ self.__engine.logger.info(repr(parameters))
try:
- context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context)
+ self.dialect.do_execute(cursor, statement, parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
self.engine.dispose()
- context.cursor.close()
+ cursor.close()
self._autorollback()
if self.__close_with_result:
self.close()
- raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+ raise exceptions.DBAPIError.instance(statement, parameters, e)
- def __executemany(self, context):
+ def _cursor_executemany(self, cursor, statement, parameters, context=None):
+ if self.__engine._should_log_info:
+ self.__engine.logger.info(statement)
+ self.__engine.logger.info(repr(parameters))
try:
- context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context)
+ self.dialect.do_executemany(cursor, statement, parameters, context=context)
except Exception, e:
if self.dialect.is_disconnect(e):
self.__connection.invalidate(e=e)
self.engine.dispose()
- context.cursor.close()
+ cursor.close()
self._autorollback()
if self.__close_with_result:
self.close()
- raise exceptions.DBAPIError.instance(context.statement, context.parameters, e)
+ raise exceptions.DBAPIError.instance(statement, parameters, e)
# poor man's multimethod/generic function thingy
executors = {
def __init__(self, context):
self.context = context
- self.connection = context._connection._branch()
self.dialect = context.dialect
def get_column_default(self, column):
return None
def exec_default_sql(self, default):
- c = expression.select([default.arg]).compile(bind=self.connection)
- return self.connection._execute_compiled(c).scalar()
-
+ conn = self.context.connection
+ c = expression.select([default.arg]).compile(bind=conn)
+ return conn._execute_compiled(c).scalar()
+
+ def execute_string(self, stmt, params=None):
+ """execute a string statement, using the raw cursor,
+ and return a scalar result."""
+ conn = self.context._connection
+ conn._cursor_execute(self.context.cursor, stmt, params)
+ return self.context.cursor.fetchone()[0]
+
def visit_column_onupdate(self, onupdate):
if isinstance(onupdate.arg, expression.ClauseElement):
return self.exec_default_sql(onupdate)