]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- allowing resultproxy to autoclose even if implicit returning is used
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Aug 2009 23:31:27 +0000 (23:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Aug 2009 23:31:27 +0000 (23:31 +0000)
- for now, lastrowid-capable dialects will use pre-execute for any defaults that aren't the real "autoincrement";
currently this is letting us treat MSSQL the same as them but we may want to improve upon this

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/test/requires.py
test/dialect/test_mssql.py
test/orm/inheritance/test_basic.py
test/sql/test_defaults.py
test/sql/test_query.py

index e1126532e11f0ea5ea228951fe8b2c5fd6e1d565..f52e011d17b709df1123696c3d60ea36e77c0855 100644 (file)
@@ -867,8 +867,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
 
         if self._enable_identity_insert:
             self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-
-    
+        
     def get_lastrowid(self):
         return self._lastrowid
         
@@ -880,7 +879,10 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 pass
 
     def get_result_proxy(self):
-        return self._result_proxy or base.ResultProxy(self)
+        if self._result_proxy:
+            return self._result_proxy
+        else:
+            return base.ResultProxy(self)
 
 class MSSQLCompiler(compiler.SQLCompiler):
 
index ba3880ca2b988fa285810782b8324a0b35bbcd3a..538ab88918bad7a85ff034c923ac439cec5b603c 100644 (file)
@@ -1086,7 +1086,7 @@ class Connection(Connectable):
         if context.should_autocommit and not self.in_transaction():
             self._commit_impl()
             
-        return context.get_result_proxy()
+        return context.get_result_proxy()._autoclose()
         
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
@@ -1615,9 +1615,26 @@ class ResultProxy(object):
         self.connection = context.root_connection
         self._echo = context.engine._should_log_info
         self._init_metadata()
-
+            
     @util.memoized_property
     def rowcount(self):
+        """Return the 'rowcount' for this result.
+        
+        The 'rowcount' reports the number of rows affected
+        by an UPDATE or DELETE statement.  It has *no* other
+        uses and is not intended to provide the number of rows
+        present from a SELECT.
+        
+        Additionally, this value is only meaningful if the
+        dialect's supports_sane_rowcount flag is True for
+        single-parameter executions, or supports_sane_multi_rowcount
+        is true for multiple parameter executions - otherwise
+        results are undefined.
+        
+        rowcount may not work at this time for a statement
+        that uses ``returning()``.
+        
+        """
         return self.context.rowcount
 
     @property
@@ -1626,7 +1643,8 @@ class ResultProxy(object):
         
         This is a DBAPI specific method and is only functional
         for those backends which support it, for statements
-        where it is appropriate.
+        where it is appropriate.  Its behavior is not 
+        consistent across backends.
         
         Usage of this method is normally unnecessary; the
         last_inserted_ids() method provides a
@@ -1641,20 +1659,27 @@ class ResultProxy(object):
         return self.context.out_parameters
     
     def _cursor_description(self):
-        metadata = self.cursor.description
-        if metadata is None:
-            return
-        else:
-            return [(r[0], r[1]) for r in metadata]
+        return self.cursor.description
             
-    def _init_metadata(self):
-        
-        metadata = self._cursor_description()
-        if metadata is None:
+    def _autoclose(self):
+        if self._metadata is None:
             # no results, get rowcount 
             # (which requires open cursor on some DB's such as firebird),
             self.rowcount
             self.close() # autoclose
+        elif self.context.isinsert and \
+            not self.context._is_explicit_returning:
+            # an insert, no explicit returning(), may need
+            # to fetch rows which were created via implicit 
+            # returning, then close
+            self.context.last_inserted_ids(self)
+            self.close()
+            
+        return self
+            
+    def _init_metadata(self):
+        self._metadata = metadata = self._cursor_description()
+        if metadata is None:
             return
         
         self._props = util.populate_column_dict(None)
@@ -1663,7 +1688,7 @@ class ResultProxy(object):
 
         typemap = self.dialect.dbapi_type_map
 
-        for i, (colname, coltype) in enumerate(metadata):
+        for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
 
             if self.dialect.description_encoding:
                 colname = colname.decode(self.dialect.description_encoding)
@@ -1738,6 +1763,9 @@ class ResultProxy(object):
         """Close this ResultProxy.
 
         Closes the underlying DBAPI cursor corresponding to the execution.
+        
+        Note that any data cached within this ResultProxy is still available.
+        For some types of results, this may include buffered rows.
 
         If this ResultProxy was generated from an implicit execution,
         the underlying Connection will also be closed (returns the
@@ -2000,8 +2028,8 @@ class FullyBufferedResultProxy(ResultProxy):
     
     """
     def _init_metadata(self):
-        self.__rowbuffer = self._buffer_rows()
         super(FullyBufferedResultProxy, self)._init_metadata()
+        self.__rowbuffer = self._buffer_rows()
         
     def _buffer_rows(self):
         return self.cursor.fetchall()
index bede3b701876718ae8248750e203e66bb5a7f5c5..6f468540c66b8f0b5c2a37a0714b81bbcf1453e9 100644 (file)
@@ -262,6 +262,11 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.statement = self.compiled = None
             self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
             self.cursor = self.create_cursor()
+    
+    @property
+    def _is_explicit_returning(self):
+        return self.compiled and \
+            getattr(self.compiled.statement, '_returning', False)
 
     @property
     def connection(self):
index c981785734584059e09e88db35b8b7c22ce01bf0..a4ab763ea840d9e8db81064f9aea1e4d755f00be 100644 (file)
@@ -802,7 +802,8 @@ class SQLCompiler(engine.Compiled):
                     # then implicit_returning/supports sequence/doesnt
                     if c.primary_key and \
                         (
-                            self.dialect.preexecute_pk_sequences or 
+                            self.dialect.preexecute_pk_sequences or
+                            c is not stmt.table._autoincrement_column or 
                             implicit_returning
                         ) and \
                         not self.inline and \
index 5da8277948f0f9d003e6024027cab0b06ac6d248..f3f4ec1911c9f0a7243408d282429dee322ba1cd 100644 (file)
@@ -131,6 +131,17 @@ def subqueries(fn):
         exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
         )
 
+def returning(fn):
+    return _chain_decorators_on(
+        fn,
+        no_support('access', 'not supported by database'),
+        no_support('sqlite', 'not supported by database'),
+        no_support('mysql', 'not supported by database'),
+        no_support('maxdb', 'not supported by database'),
+        no_support('sybase', 'not supported by database'),
+        no_support('informix', 'not supported by database'),
+    )
+    
 def two_phase_transactions(fn):
     """Target database must support two-phase transactions."""
     return _chain_decorators_on(
index 05031c068b74707973feafe60b9548153d7f565e..989b538bd6c7929798c7428c5e9cdfe7f76743f3 100644 (file)
@@ -681,26 +681,22 @@ class TypesTest(TestBase, AssertsExecutionResults, ComparesTables):
         )
         metadata.create_all()
 
-        try:
-            test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
-                          '-1500000.00000000000000000000', '1500000',
-                          '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
-                          '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
-                          '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
-                          '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
-                          '-02452E-2', '45125E-2',
-                          '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
-                          '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
-                          '00000000000000.1E+12', '000000000000.2E-32']
-
-            for value in test_items:
-                numeric_table.insert().execute(numericcol=value)
-
-            for value in select([numeric_table.c.numericcol]).execute():
-                assert value[0] in test_items, "%s not in test_items" % value[0]
-
-        except Exception, e:
-            raise e
+        test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
+                      '-1500000.00000000000000000000', '1500000',
+                      '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
+                      '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
+                      '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
+                      '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
+                      '-02452E-2', '45125E-2',
+                      '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
+                      '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
+                      '00000000000000.1E+12', '000000000000.2E-32']
+
+        for value in test_items:
+            numeric_table.insert().execute(numericcol=value)
+
+        for value in select([numeric_table.c.numericcol]).execute():
+            assert value[0] in test_items, "%s not in test_items" % value[0]
 
     def test_float(self):
         float_table = Table('float_table', metadata,
index e9cd6093d2d57048f85fd265d68e576458dc73bc..b2e00de3598261c24b53847425964b128ea35cdb 100644 (file)
@@ -450,7 +450,6 @@ class VersioningTest(_base.MappedTest):
             Column('parent', Integer, ForeignKey('base.id'))
             )
 
-    @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
     @engines.close_open_connections
     def test_save_update(self):
         class Base(_fixtures.Base):
@@ -500,7 +499,6 @@ class VersioningTest(_base.MappedTest):
         s2.subdata = 'sess2 subdata'
         sess2.flush()
 
-    @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
     def test_delete(self):
         class Base(_fixtures.Base):
             pass
index f2bc5a53b4b4dc1b8222426422750214ad42c658..87a1a24ddff26c7e04f59e3156c45d4970a4809b 100644 (file)
@@ -146,7 +146,7 @@ class DefaultTest(testing.TestBase):
             assert_raises_message(sa.exc.ArgumentError,
                                      ex_msg,
                                      sa.ColumnDefault, fn)
-
+    
     def test_arg_signature(self):
         def fn1(): pass
         def fn2(): pass
@@ -369,18 +369,28 @@ class PKDefaultTest(_base.TablesTest):
               Column('id', Integer, primary_key=True,
                      default=sa.select([func.max(t2.c.nextid)]).as_scalar()),
               Column('data', String(30)))
-
+    
+    @testing.requires.returning
+    def test_with_implicit_returning(self):
+        self._test(True)
+        
+    def test_regular(self):
+        self._test(False)
+        
     @testing.resolve_artifact_names
-    def test_basic(self):
-        t2.insert().execute(nextid=1)
-        r = t1.insert().execute(data='hi')
+    def _test(self, returning):
+        if not returning and not testing.db.dialect.implicit_returning:
+            engine = testing.db
+        else:
+            engine = engines.testing_engine(options={'implicit_returning':returning})
+        engine.execute(t2.insert(), nextid=1)
+        r = engine.execute(t1.insert(), data='hi')
         eq_([1], r.last_inserted_ids())
 
-        t2.insert().execute(nextid=2)
-        r = t1.insert().execute(data='there')
+        engine.execute(t2.insert(), nextid=2)
+        r = engine.execute(t1.insert(), data='there')
         eq_([2], r.last_inserted_ids())
 
-
 class PKIncrementTest(_base.TablesTest):
     run_define_tables = 'each'
 
index b3a9eb0ccbe2adbb48b17b9af567055753487cb0..2e56e0d3ec66034b7857386ea4cfe6552bf981d3 100644 (file)
@@ -80,7 +80,7 @@ class QueryTest(TestBase):
                     ret[c.key] = row[c]
             return ret
 
-        if testing.against('firebird', 'postgres', 'oracle', 'mssql'):
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
             test_engines = [
                 engines.testing_engine(options={'implicit_returning':False}),
                 engines.testing_engine(options={'implicit_returning':True}),
@@ -166,7 +166,22 @@ class QueryTest(TestBase):
         r = t6.insert().values(manual_id=id).execute()
         eq_(r.last_inserted_ids(), [12, 1])
 
-
+    def test_autoclose_on_insert(self):
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
+            test_engines = [
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+            ]
+        else:
+            test_engines = [testing.db]
+            
+        for engine in test_engines:
+        
+            r = engine.execute(users.insert(), 
+                {'user_name':'jack'},
+            )
+            assert r.closed
+        
     def test_row_iteration(self):
         users.insert().execute(
             {'user_id':7, 'user_name':'jack'},