From a76927f584dab481383592645a2a471fed37ecf9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 28 Feb 2010 23:51:54 +0000 Subject: [PATCH] - the execution sequence pulls all rowcount/last inserted ID info from the cursor before commit() is called on the DBAPI connection in an "autocommit" scenario. This helps mxodbc with rowcount and is probably a good idea overall. - cx_oracle wants list(), not tuple(), for empty execute. - cleaned up plain SQL param handling --- CHANGES | 5 ++ lib/sqlalchemy/dialects/mssql/base.py | 6 +- lib/sqlalchemy/dialects/mssql/mxodbc.py | 14 +---- lib/sqlalchemy/dialects/oracle/cx_oracle.py | 2 + lib/sqlalchemy/engine/base.py | 55 ++++++++++++----- lib/sqlalchemy/engine/default.py | 36 ++++++----- test/engine/test_execute.py | 67 +++++++++++++++------ test/sql/test_rowcount.py | 9 +-- 8 files changed, 126 insertions(+), 68 deletions(-) diff --git a/CHANGES b/CHANGES index c4495552a3..d7695c40f1 100644 --- a/CHANGES +++ b/CHANGES @@ -190,6 +190,11 @@ CHANGES Note that it is *not* built/installed by default. See README for installation instructions. + - the execution sequence pulls all rowcount/last inserted ID + info from the cursor before commit() is called on the + DBAPI connection in an "autocommit" scenario. This helps + mxodbc with rowcount and is probably a good idea overall. + - Opened up logging a bit such that isEnabledFor() is called more often, so that changes to the log level for engine/pool will be reflected on next connect. This adds a small diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index d1ccf44e2c..254aa54fd3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -826,7 +826,11 @@ class MSExecutionContext(default.DefaultExecutionContext): def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % + self.dialect.\ + identifier_preparer.\ + format_table(self.compiled.statement.table) + ) except: pass diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index 73cf1346e0..85c0ac6ac4 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -9,19 +9,7 @@ from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc # The pyodbc execution context seems to work for mxODBC; reuse it here class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): - - def post_exec(self): - # snag rowcount before the cursor is closed - if not self.cursor.description: - self._rowcount = self.cursor.rowcount - super(MSExecutionContext_mxodbc, self).post_exec() - - @property - def rowcount(self): - if hasattr(self, '_rowcount'): - return self._rowcount - else: - return self.cursor.rowcount + pass class MSDialect_mxodbc(MxODBCConnector, MSDialect): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 854eb875a8..47909f8d17 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -307,6 +307,8 @@ class Oracle_cx_oracle(OracleDialect): driver = "cx_oracle" colspecs = colspecs + execute_sequence_format = list + def __init__(self, auto_setinputsizes=True, auto_convert_lobs=True, diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b4f2524d6e..46907dfcf9 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -77,6 +77,10 @@ class Dialect(object): execution_ctx_cls a :class:`ExecutionContext` class used to handle statement execution + execute_sequence_format + either the 'tuple' or 'list' type, depending on what cursor.execute() + accepts for the second argument (they vary). + preparer a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to quote identifiers. @@ -1055,6 +1059,7 @@ class Connection(Connectable): In the case of 'raw' execution which accepts positional parameters, it may be a list of tuples or lists. + """ if not multiparams: @@ -1104,7 +1109,9 @@ class Connection(Connectable): keys = [] context = self.__create_execution_context( - compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), + compiled_sql=elem.compile( + dialect=self.dialect, column_keys=keys, + inline=len(params) > 1), parameters=params ) return self.__execute_context(context) @@ -1128,9 +1135,15 @@ class Connection(Connectable): context.pre_exec() if context.executemany: - self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) + self._cursor_executemany( + context.cursor, + context.statement, + context.parameters, context=context) else: - self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) + self._cursor_execute( + context.cursor, + context.statement, + context.parameters[0], context=context) if context.compiled: context.post_exec() @@ -1138,10 +1151,17 @@ class Connection(Connectable): if context.isinsert and not context.executemany: context.post_insert() + # create a resultproxy, get rowcount/implicit RETURNING + # rows, close cursor if no further results pending + r = context.get_result_proxy()._autoclose() + if self.__transaction is None and context.should_autocommit: self._commit_impl() - - return context.get_result_proxy()._autoclose() + + if r.closed and self.should_close_with_result: + self.close() + + return r def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): @@ -1893,6 +1913,7 @@ class ResultProxy(object): _process_row = RowProxy out_parameters = None + _can_close_connection = False def __init__(self, context): self.context = context @@ -1904,7 +1925,6 @@ class ResultProxy(object): context.engine._should_log_debug() self._init_metadata() - def _init_metadata(self): metadata = self._cursor_description() if metadata is None: @@ -1962,21 +1982,26 @@ class ResultProxy(object): return self.cursor.description def _autoclose(self): + """called by the Connection to autoclose cursors that have no pending results + beyond those used by an INSERT/UPDATE/DELETE with no explicit RETURNING clause. + + """ if self.context.isinsert: if self.context._is_implicit_returning: self.context._fetch_implicit_returning(self) - self.close() + self.close(_autoclose_connection=False) elif not self.context._is_explicit_returning: - self.close() + self.close(_autoclose_connection=False) elif self._metadata is None: # no results, get rowcount - # (which requires open cursor on some DB's such as firebird), + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc), self.rowcount - self.close() # autoclose - + self.close(_autoclose_connection=False) + return self - - def close(self): + + def close(self, _autoclose_connection=True): """Close this ResultProxy. Closes the underlying DBAPI cursor corresponding to the execution. @@ -1992,12 +2017,14 @@ class ResultProxy(object): * all result rows are exhausted using the fetchXXX() methods. * cursor.description is None. + """ if not self.closed: self.closed = True self.cursor.close() - if self.connection.should_close_with_result: + if _autoclose_connection and \ + self.connection.should_close_with_result: self.connection.close() def __iter__(self): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4d4fd7c719..cd2c103938 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -30,6 +30,10 @@ class DefaultDialect(base.Dialect): preparer = compiler.IdentifierPreparer supports_alter = True + # most DBAPIs happy with this for execute(). + # not cx_oracle. + execute_sequence_format = tuple + supports_sequences = False sequences_optional = False preexecute_autoincrement_sequences = False @@ -365,7 +369,7 @@ class DefaultExecutionContext(base.ExecutionContext): @util.memoized_property def _default_params(self): if self.dialect.positional: - return () + return self.dialect.execute_sequence_format() else: return {} @@ -392,21 +396,23 @@ class DefaultExecutionContext(base.ExecutionContext): """Apply string encoding to the keys of dictionary-based bind parameters. This is only used executing textual, non-compiled SQL expressions. + """ - - if self.dialect.positional or self.dialect.supports_unicode_statements: - if params: + + if not params: + return [self._default_params] + elif isinstance(params[0], self.dialect.execute_sequence_format): + return params + elif isinstance(params[0], dict): + if self.dialect.supports_unicode_statements: return params else: - return [self._default_params] + def proc(d): + return dict((k.encode(self.dialect.encoding), d[k]) for k in d) + return [proc(d) for d in params] or [{}] else: - def proc(d): - # sigh, sometimes we get positional arguments with a dialect - # that doesnt specify positional (because of execute_text()) - if not isinstance(d, dict): - return d - return dict((k.encode(self.dialect.encoding), d[k]) for k in d) - return [proc(d) for d in params] or [{}] + return [self.dialect.execute_sequence_format(p) for p in params] + def __convert_compiled_params(self, compiled_parameters): """Convert the dictionary of bind parameter values into a dict or list @@ -423,7 +429,7 @@ class DefaultExecutionContext(base.ExecutionContext): param.append(processors[key](compiled_params[key])) else: param.append(compiled_params[key]) - parameters.append(tuple(param)) + parameters.append(self.dialect.execute_sequence_format(param)) else: encode = not self.dialect.supports_unicode_statements for compiled_params in compiled_parameters: @@ -442,7 +448,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: param[key] = compiled_params[key] parameters.append(param) - return tuple(parameters) + return self.dialect.execute_sequence_format(parameters) def should_autocommit_text(self, statement): return AUTOCOMMIT_REGEXP.match(statement) @@ -514,7 +520,7 @@ class DefaultExecutionContext(base.ExecutionContext): def _fetch_implicit_returning(self, resultproxy): table = self.compiled.statement.table - row = resultproxy.first() + row = resultproxy.fetchone() self._inserted_primary_key = [v is not None and v or row[c] for c, v in zip(table.primary_key, self._inserted_primary_key) diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 4a45fceb31..1752fda0dd 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -12,9 +12,13 @@ users, metadata = None, None class ExecuteTest(TestBase): @classmethod def setup_class(cls): - global users, metadata + global users, users_autoinc, metadata metadata = MetaData(testing.db) users = Table('users', metadata, + Column('user_id', INT, primary_key = True, autoincrement=False), + Column('user_name', VARCHAR(20)), + ) + users_autoinc = Table('users_autoinc', metadata, Column('user_id', INT, primary_key = True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) @@ -28,16 +32,22 @@ class ExecuteTest(TestBase): def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc', 'mysql+oursql') + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', '+pyodbc', '+mxodbc', '+zxjdbc', 'mysql+oursql') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) conn.execute("insert into users (user_id, user_name) values (?, ?)", [2,"fred"]) - conn.execute("insert into users (user_id, user_name) values (?, ?)", [3,"ed"], [4,"horse"]) - conn.execute("insert into users (user_id, user_name) values (?, ?)", (5,"barney"), (6,"donkey")) + conn.execute("insert into users (user_id, user_name) values (?, ?)", + [3,"ed"], + [4,"horse"]) + conn.execute("insert into users (user_id, user_name) values (?, ?)", + (5,"barney"), (6,"donkey")) conn.execute("insert into users (user_id, user_name) values (?, ?)", 7, 'sally') res = conn.execute("select * from users order by user_id") - assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] + assert res.fetchall() == [(1, "jack"), (2, "fred"), + (3, "ed"), (4, "horse"), + (5, "barney"), (6, "donkey"), + (7, 'sally')] conn.execute("delete from users") @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql') @@ -46,11 +56,15 @@ class ExecuteTest(TestBase): def test_raw_sprintf(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"]) - conn.execute("insert into users (user_id, user_name) values (%s, %s)", [2,"ed"], [3,"horse"]) + conn.execute("insert into users (user_id, user_name) values (%s, %s)", + [2,"ed"], + [3,"horse"]) conn.execute("insert into users (user_id, user_name) values (%s, %s)", 4, 'sally') conn.execute("insert into users (user_id) values (%s)", 5) res = conn.execute("select * from users order by user_id") - assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally'), (5, None)] + assert res.fetchall() == [(1, "jack"), (2, "ed"), + (3, "horse"), (4, 'sally'), + (5, None)] conn.execute("delete from users") # pyformat is supported for mysql, but skipping because a few driver @@ -59,9 +73,12 @@ class ExecuteTest(TestBase): @testing.fails_on_everything_except('postgresql+psycopg2', 'postgresql+pypostgresql') def test_raw_python(self): for conn in (testing.db, testing.db.connect()): - conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) - conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) - conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", id=4, name='sally') + conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", + {'id':1, 'name':'jack'}) + conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", + {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) + conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", + id=4, name='sally') res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") @@ -69,9 +86,12 @@ class ExecuteTest(TestBase): @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle') def test_raw_named(self): for conn in (testing.db, testing.db.connect()): - conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) - conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) - conn.execute("insert into users (user_id, user_name) values (:id, :name)", id=4, name='sally') + conn.execute("insert into users (user_id, user_name) values (:id, :name)", + {'id':1, 'name':'jack'}) + conn.execute("insert into users (user_id, user_name) values (:id, :name)", + {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'}) + conn.execute("insert into users (user_id, user_name) values (:id, :name)", + id=4, name='sally') res = conn.execute("select * from users order by user_id") assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") @@ -86,8 +106,8 @@ class ExecuteTest(TestBase): def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" - result = testing.db.execute(users.insert().values(user_name=bindparam('name')), []) - eq_(testing.db.execute(users.select()).fetchall(), [ + result = testing.db.execute(users_autoinc.insert().values(user_name=bindparam('name')), []) + eq_(testing.db.execute(users_autoinc.select()).fetchall(), [ (1, None) ]) @@ -124,17 +144,25 @@ class ProxyConnectionTest(TestBase): for engine in ( engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())), - engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal')) + engines.testing_engine(options=dict( + implicit_returning=False, + proxy=MyProxy(), + strategy='threadlocal')) ): m = MetaData(engine) - t1 = Table('t1', m, Column('c1', Integer, primary_key=True), Column('c2', String(50), default=func.lower('Foo'), primary_key=True)) + t1 = Table('t1', m, + Column('c1', Integer, primary_key=True), + Column('c2', String(50), default=func.lower('Foo'), primary_key=True) + ) m.create_all() try: t1.insert().execute(c1=5, c2='some data') t1.insert().execute(c1=6) - assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')] + eq_(engine.execute("select * from t1").fetchall(), + [(5, 'some data'), (6, 'foo')] + ) finally: m.drop_all() @@ -165,7 +193,8 @@ class ProxyConnectionTest(TestBase): cursor = [ ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, (5, 'some data')), - ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params), # bind param name 'lower_2' might be incorrect + # bind param name 'lower_2' might be incorrect + ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params), ("select * from t1", {}, ()), ("DROP TABLE t1", {}, ()) ] diff --git a/test/sql/test_rowcount.py b/test/sql/test_rowcount.py index 6da25b9145..9577b104a6 100644 --- a/test/sql/test_rowcount.py +++ b/test/sql/test_rowcount.py @@ -54,22 +54,19 @@ class FoundRowsTest(TestBase, AssertsExecutionResults): department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') print "expecting 3, dialect reports %s" % r.rowcount - if testing.db.dialect.supports_sane_rowcount: - assert r.rowcount == 3 + assert r.rowcount == 3 def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') print "expecting 3, dialect reports %s" % r.rowcount - if testing.db.dialect.supports_sane_rowcount: - assert r.rowcount == 3 + assert r.rowcount == 3 def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() print "expecting 3, dialect reports %s" % r.rowcount - if testing.db.dialect.supports_sane_rowcount: - assert r.rowcount == 3 + assert r.rowcount == 3 -- 2.47.3