From 19a3ae94d701b7c0597fd62f6f9b34650af0fef4 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 2 Aug 2009 23:31:27 +0000 Subject: [PATCH] - allowing resultproxy to autoclose even if implicit returning is used - for now, lastrowid-capable dialects will use pre-execute for any defaults that arent the real "autoincrement"; currently this is letting us treat MSSQL the same as them but we may want to improve upon this --- lib/sqlalchemy/dialects/mssql/base.py | 8 ++-- lib/sqlalchemy/engine/base.py | 56 ++++++++++++++++++++------- lib/sqlalchemy/engine/default.py | 5 +++ lib/sqlalchemy/sql/compiler.py | 3 +- lib/sqlalchemy/test/requires.py | 11 ++++++ test/dialect/test_mssql.py | 36 ++++++++--------- test/orm/inheritance/test_basic.py | 2 - test/sql/test_defaults.py | 26 +++++++++---- test/sql/test_query.py | 19 ++++++++- 9 files changed, 116 insertions(+), 50 deletions(-) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index e1126532e1..f52e011d17 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -867,8 +867,7 @@ class MSExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - - + def get_lastrowid(self): return self._lastrowid @@ -880,7 +879,10 @@ class MSExecutionContext(default.DefaultExecutionContext): pass def get_result_proxy(self): - return self._result_proxy or base.ResultProxy(self) + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) class MSSQLCompiler(compiler.SQLCompiler): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ba3880ca2b..538ab88918 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1086,7 +1086,7 @@ class Connection(Connectable): if context.should_autocommit and not self.in_transaction(): self._commit_impl() - return context.get_result_proxy() + return context.get_result_proxy()._autoclose() def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): @@ -1615,9 +1615,26 @@ class ResultProxy(object): self.connection = context.root_connection self._echo = context.engine._should_log_info self._init_metadata() - + @util.memoized_property def rowcount(self): + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows affected + by an UPDATE or DELETE statement. It has *no* other + uses and is not intended to provide the number of rows + present from a SELECT. + + Additionally, this value is only meaningful if the + dialect's supports_sane_rowcount flag is True for + single-parameter executions, or supports_sane_multi_rowcount + is true for multiple parameter executions - otherwise + results are undefined. + + rowcount may not work at this time for a statement + that uses ``returning()``. + + """ return self.context.rowcount @property @@ -1626,7 +1643,8 @@ class ResultProxy(object): This is a DBAPI specific method and is only functional for those backends which support it, for statements - where it is appropriate. + where it is appropriate. It's behavior is not + consistent across backends. Usage of this method is normally unnecessary; the last_inserted_ids() method provides a @@ -1641,20 +1659,27 @@ class ResultProxy(object): return self.context.out_parameters def _cursor_description(self): - metadata = self.cursor.description - if metadata is None: - return - else: - return [(r[0], r[1]) for r in metadata] + return self.cursor.description - def _init_metadata(self): - - metadata = self._cursor_description() - if metadata is None: + def _autoclose(self): + if self._metadata is None: # no results, get rowcount # (which requires open cursor on some DB's such as firebird), self.rowcount self.close() # autoclose + elif self.context.isinsert and \ + not self.context._is_explicit_returning: + # an insert, no explicit returning(), may need + # to fetch rows which were created via implicit + # returning, then close + self.context.last_inserted_ids(self) + self.close() + + return self + + def _init_metadata(self): + self._metadata = metadata = self._cursor_description() + if metadata is None: return self._props = util.populate_column_dict(None) @@ -1663,7 +1688,7 @@ class ResultProxy(object): typemap = self.dialect.dbapi_type_map - for i, (colname, coltype) in enumerate(metadata): + for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) @@ -1738,6 +1763,9 @@ class ResultProxy(object): """Close this ResultProxy. Closes the underlying DBAPI cursor corresponding to the execution. + + Note that any data cached within this ResultProxy is still available. + For some types of results, this may include buffered rows. If this ResultProxy was generated from an implicit execution, the underlying Connection will also be closed (returns the @@ -2000,8 +2028,8 @@ class FullyBufferedResultProxy(ResultProxy): """ def _init_metadata(self): - self.__rowbuffer = self._buffer_rows() super(FullyBufferedResultProxy, self)._init_metadata() + self.__rowbuffer = self._buffer_rows() def _buffer_rows(self): return self.cursor.fetchall() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index bede3b7018..6f468540c6 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -262,6 +262,11 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = self.compiled = None self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() + + @property + def _is_explicit_returning(self): + return self.compiled and \ + getattr(self.compiled.statement, '_returning', False) @property def connection(self): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c981785734..a4ab763ea8 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -802,7 +802,8 @@ class SQLCompiler(engine.Compiled): # then implicit_returning/supports sequence/doesnt if c.primary_key and \ ( - self.dialect.preexecute_pk_sequences or + self.dialect.preexecute_pk_sequences or + c is not stmt.table._autoincrement_column or implicit_returning ) and \ not self.inline and \ diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index 5da8277948..f3f4ec1911 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -131,6 +131,17 @@ def subqueries(fn): exclude('mysql', '<', (4, 1, 1), 'no subquery support'), ) +def returning(fn): + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('maxdb', 'not supported by database'), + no_support('sybase', 'not supported by database'), + no_support('informix', 'not supported by database'), + ) + def two_phase_transactions(fn): """Target database must support two-phase transactions.""" return _chain_decorators_on( diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 05031c068b..989b538bd6 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -681,26 +681,22 @@ class TypesTest(TestBase, AssertsExecutionResults, ComparesTables): ) metadata.create_all() - try: - test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', - '-1500000.00000000000000000000', '1500000', - '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', - '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', - '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', - '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', - '-02452E-2', '45125E-2', - '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', - '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', - '00000000000000.1E+12', '000000000000.2E-32'] - - for value in test_items: - numeric_table.insert().execute(numericcol=value) - - for value in select([numeric_table.c.numericcol]).execute(): - assert value[0] in test_items, "%s not in test_items" % value[0] - - except Exception, e: - raise e + test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', + '-1500000.00000000000000000000', '1500000', + '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', + '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', + '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', + '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', + '-02452E-2', '45125E-2', + '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', + '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', + '00000000000000.1E+12', '000000000000.2E-32'] + + for value in test_items: + numeric_table.insert().execute(numericcol=value) + + for value in select([numeric_table.c.numericcol]).execute(): + assert value[0] in test_items, "%s not in test_items" % value[0] def test_float(self): float_table = Table('float_table', metadata, diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index e9cd6093d2..b2e00de359 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -450,7 +450,6 @@ class VersioningTest(_base.MappedTest): Column('parent', Integer, ForeignKey('base.id')) ) - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') @engines.close_open_connections def test_save_update(self): class Base(_fixtures.Base): @@ -500,7 +499,6 @@ class VersioningTest(_base.MappedTest): s2.subdata = 'sess2 subdata' sess2.flush() - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') def test_delete(self): class Base(_fixtures.Base): pass diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index f2bc5a53b4..87a1a24ddf 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -146,7 +146,7 @@ class DefaultTest(testing.TestBase): assert_raises_message(sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn) - + def test_arg_signature(self): def fn1(): pass def fn2(): pass @@ -369,18 +369,28 @@ class PKDefaultTest(_base.TablesTest): Column('id', Integer, primary_key=True, default=sa.select([func.max(t2.c.nextid)]).as_scalar()), Column('data', String(30))) - + + @testing.requires.returning + def test_with_implicit_returning(self): + self._test(True) + + def test_regular(self): + self._test(False) + @testing.resolve_artifact_names - def test_basic(self): - t2.insert().execute(nextid=1) - r = t1.insert().execute(data='hi') + def _test(self, returning): + if not returning and not testing.db.dialect.implicit_returning: + engine = testing.db + else: + engine = engines.testing_engine(options={'implicit_returning':returning}) + engine.execute(t2.insert(), nextid=1) + r = engine.execute(t1.insert(), data='hi') eq_([1], r.last_inserted_ids()) - t2.insert().execute(nextid=2) - r = t1.insert().execute(data='there') + engine.execute(t2.insert(), nextid=2) + r = engine.execute(t1.insert(), data='there') eq_([2], r.last_inserted_ids()) - class PKIncrementTest(_base.TablesTest): run_define_tables = 'each' diff --git a/test/sql/test_query.py b/test/sql/test_query.py index b3a9eb0ccb..2e56e0d3ec 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -80,7 +80,7 @@ class QueryTest(TestBase): ret[c.key] = row[c] return ret - if testing.against('firebird', 'postgres', 'oracle', 'mssql'): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): test_engines = [ engines.testing_engine(options={'implicit_returning':False}), engines.testing_engine(options={'implicit_returning':True}), @@ -166,7 +166,22 @@ class QueryTest(TestBase): r = t6.insert().values(manual_id=id).execute() eq_(r.last_inserted_ids(), [12, 1]) - + def test_autoclose_on_insert(self): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + + r = engine.execute(users.insert(), + {'user_name':'jack'}, + ) + assert r.closed + def test_row_iteration(self): users.insert().execute( {'user_id':7, 'user_name':'jack'}, -- 2.47.3