]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
inline one_or_none
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 May 2020 13:41:51 +0000 (09:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 May 2020 15:06:21 +0000 (11:06 -0400)
Remove a bunch of unnecessary functions for this case.
add test coverage to ensure uniqueness logic works.

Change-Id: I2e6232c5667a3277b0ec8d7e47085a267f23d75f

lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/result.py
test/aaa_profiling/test_resultset.py
test/base/test_result.py
test/profiles.txt

index a393f8da76500812c8cdff3f541a5ab1344a09e1..8d1a1bb57ff4f02c886b30a60c78ea5e6d0c0cb8 100644 (file)
@@ -23,7 +23,6 @@ from ..sql import expression
 from ..sql import sqltypes
 from ..sql import util as sql_util
 from ..sql.base import _generative
-from ..sql.base import HasMemoized
 from ..sql.compiler import RM_NAME
 from ..sql.compiler import RM_OBJECTS
 from ..sql.compiler import RM_RENDERED_NAME
@@ -793,7 +792,7 @@ class ResultFetchStrategy(object):
     def yield_per(self, result, num):
         return
 
-    def fetchone(self, result):
+    def fetchone(self, result, hard_close=False):
         raise NotImplementedError()
 
     def fetchmany(self, result, size=None):
@@ -825,7 +824,7 @@ class NoCursorFetchStrategy(ResultFetchStrategy):
     def hard_close(self, result):
         pass
 
-    def fetchone(self, result):
+    def fetchone(self, result, hard_close=False):
         return self._non_result(result, None)
 
     def fetchmany(self, result, size=None):
@@ -927,11 +926,11 @@ class CursorFetchStrategy(ResultFetchStrategy):
             growth_factor=0,
         )
 
-    def fetchone(self, result):
+    def fetchone(self, result, hard_close=False):
         try:
             row = self.dbapi_cursor.fetchone()
             if row is None:
-                result._soft_close()
+                result._soft_close(hard=hard_close)
             return row
         except BaseException as e:
             self.handle_exception(result, e)
@@ -1065,12 +1064,12 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
         self._rowbuffer.clear()
         super(BufferedRowCursorFetchStrategy, self).hard_close(result)
 
-    def fetchone(self, result):
+    def fetchone(self, result, hard_close=False):
         if not self._rowbuffer:
             self._buffer_rows(result)
             if not self._rowbuffer:
                 try:
-                    result._soft_close()
+                    result._soft_close(hard=hard_close)
                 except BaseException as e:
                     self.handle_exception(result, e)
                 return None
@@ -1137,11 +1136,11 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
         self._rowbuffer.clear()
         super(FullyBufferedCursorFetchStrategy, self).hard_close(result)
 
-    def fetchone(self, result):
+    def fetchone(self, result, hard_close=False):
         if self._rowbuffer:
             return self._rowbuffer.popleft()
         else:
-            result._soft_close()
+            result._soft_close(hard=hard_close)
             return None
 
     def fetchmany(self, result, size=None):
@@ -1222,10 +1221,19 @@ class BaseCursorResult(object):
         self.dialect = context.dialect
         self.cursor = context.cursor
         self.connection = context.root_connection
-        self._echo = (
+        self._echo = echo = (
             self.connection._echo and context.engine._should_log_debug()
         )
 
+        if echo:
+            log = self.context.engine.logger.debug
+
+            def log_row(row):
+                log("Row %r", sql_util._repr_row(row))
+                return row
+
+            self._row_logging_fn = log_row
+
         # this is a hook used by dialects to change the strategy,
         # so for the moment we have to keep calling this every time
         # :(
@@ -1616,19 +1624,6 @@ class CursorResult(BaseCursorResult, Result):
     _cursor_metadata = CursorResultMetaData
     _cursor_strategy_cls = CursorFetchStrategy
 
-    @HasMemoized.memoized_attribute
-    def _row_logging_fn(self):
-        if self._echo:
-            log = self.context.engine.logger.debug
-
-            def log_row(row):
-                log("Row %r", sql_util._repr_row(row))
-                return row
-
-            return log_row
-        else:
-            return None
-
     def _fetchiter_impl(self):
         fetchone = self.cursor_strategy.fetchone
 
@@ -1638,8 +1633,8 @@ class CursorResult(BaseCursorResult, Result):
                 break
             yield row
 
-    def _fetchone_impl(self):
-        return self.cursor_strategy.fetchone(self)
+    def _fetchone_impl(self, hard_close=False):
+        return self.cursor_strategy.fetchone(self, hard_close)
 
     def _fetchall_impl(self):
         return self.cursor_strategy.fetchall(self)
index ce844eb408b3efb071e657dd703fbfbd470b515b..4e6b22820d66c17effb80e98b07fd969bddc489a 100644 (file)
@@ -725,10 +725,6 @@ class Result(InPlaceGenerative):
     def _onerow_getter(self):
         make_row = self._row_getter()
 
-        # TODO: this is a lot for results that are only one row.
-        # all of this could be in _only_one_row except for fetchone()
-        # and maybe __next__
-
         post_creational_filter = self._post_creational_filter
 
         if self._unique_filter_state:
@@ -845,7 +841,7 @@ class Result(InPlaceGenerative):
     def _fetchiter_impl(self):
         raise NotImplementedError()
 
-    def _fetchone_impl(self):
+    def _fetchone_impl(self, hard_close=False):
         raise NotImplementedError()
 
     def _fetchall_impl(self):
@@ -943,30 +939,69 @@ class Result(InPlaceGenerative):
         return self._allrow_getter(self)
 
     def _only_one_row(self, raise_for_second_row, raise_for_none):
-        row = self._onerow_getter(self)
-        if row is _NO_ROW:
+        onerow = self._fetchone_impl
+
+        row = onerow(hard_close=True)
+        if row is None:
             if raise_for_none:
-                self._soft_close(hard=True)
                 raise exc.NoResultFound(
                     "No row was found when one was required"
                 )
             else:
                 return None
-        else:
-            if raise_for_second_row:
-                next_row = self._onerow_getter(self)
+
+        make_row = self._row_getter()
+
+        row = make_row(row) if make_row else row
+
+        if raise_for_second_row:
+            if self._unique_filter_state:
+                # for no second row but uniqueness, need to essentially
+                # consume the entire result :(
+                uniques, strategy = self._unique_strategy
+
+                existing_row_hash = strategy(row) if strategy else row
+
+                while True:
+                    next_row = onerow(hard_close=True)
+                    if next_row is None:
+                        next_row = _NO_ROW
+                        break
+
+                    next_row = make_row(next_row) if make_row else next_row
+
+                    if strategy:
+                        if existing_row_hash == strategy(next_row):
+                            continue
+                    elif row == next_row:
+                        continue
+                    # here, we have a row and it's different
+                    break
             else:
-                next_row = _NO_ROW
-            self._soft_close(hard=True)
+                next_row = onerow(hard_close=True)
+                if next_row is None:
+                    next_row = _NO_ROW
+
             if next_row is not _NO_ROW:
+                self._soft_close(hard=True)
                 raise exc.MultipleResultsFound(
                     "Multiple rows were found when exactly one was required"
                     if raise_for_none
                     else "Multiple rows were found when one or none "
                     "was required"
                 )
-            else:
-                return row
+        else:
+            next_row = _NO_ROW
+
+        if not raise_for_second_row:
+            # if we checked for second row then that would have
+            # closed us :)
+            self._soft_close(hard=True)
+        post_creational_filter = self._post_creational_filter
+        if post_creational_filter:
+            row = post_creational_filter(row)
+
+        return row
 
     def first(self):
         """Fetch the first row or None if no row is present.
@@ -1121,12 +1156,13 @@ class IteratorResult(Result):
     def _fetchiter_impl(self):
         return self.iterator
 
-    def _fetchone_impl(self):
-        try:
-            return next(self.iterator)
-        except StopIteration:
-            self._soft_close()
+    def _fetchone_impl(self, hard_close=False):
+        row = next(self.iterator, _NO_ROW)
+        if row is _NO_ROW:
+            self._soft_close(hard=hard_close)
             return None
+        else:
+            return row
 
     def _fetchall_impl(self):
         try:
index 0fdc2b4988f8421d62c4ed4f200858add3987c1c..b22676ad775dd6e393df46ef8638994928f8968a 100644 (file)
@@ -120,6 +120,35 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults):
             for row in conn.execute(t.select()).mappings().fetchall():
                 [row["field%d" % fnum] for fnum in range(NUM_FIELDS)]
 
+    @testing.combinations(
+        (False, 0), (True, 1), (False, 1), (False, 2),
+    )
+    def test_one_or_none(self, one_or_first, rows_present):
+        # TODO: this is not testing the ORM level "scalar_mapping"
+        # mode which has a different performance profile
+        with testing.db.connect() as conn:
+            stmt = t.select()
+            if rows_present == 0:
+                stmt = stmt.where(1 == 0)
+            elif rows_present == 1:
+                stmt = stmt.limit(1)
+
+            result = conn.execute(stmt)
+
+            @profiling.function_call_count()
+            def go():
+                if one_or_first:
+                    result.one()
+                else:
+                    result.first()
+
+            try:
+                go()
+            finally:
+                # hmmmm, connection close context manager does not
+                # seem to be handling this for a profile that skips
+                result.close()
+
     def test_contains_doesnt_compile(self):
         row = t.select().execute().first()
         c1 = Column("some column", Integer) + Column(
index 7628318a567fda9b1e68c58ec233a534a50cdf7f..b179c3462035e3f800902977928d496463c85378 100644 (file)
@@ -399,6 +399,37 @@ class ResultTest(fixtures.TestBase):
 
         eq_(result.all(), [])
 
+    def test_one_unique(self):
+    # assert that one() counts rows after uniqueness has been applied.
+    # this would raise if we didn't have unique
+        result = self._fixture(data=[(1, 1, 1), (1, 1, 1)])
+
+        row = result.unique().one()
+        eq_(row, (1, 1, 1))
+
+    def test_one_unique_tricky_one(self):
+    # one() needs to keep consuming rows until it finds one that is
+    # not a duplicate.  unique() really slows things down
+        result = self._fixture(
+            data=[(1, 1, 1), (1, 1, 1), (1, 1, 1), (2, 1, 1)]
+        )
+
+        assert_raises(exc.MultipleResultsFound, result.unique().one)
+
+    def test_one_unique_mapping(self):
+    # assert that one() counts rows after uniqueness has been applied.
+    # this would raise if we didn't have unique
+        result = self._fixture(data=[(1, 1, 1), (1, 1, 1)])
+
+        row = result.mappings().unique().one()
+        eq_(row, {"a": 1, "b": 1, "c": 1})
+
+    def test_one_mapping(self):
+        result = self._fixture(num_rows=1)
+
+        row = result.mappings().one()
+        eq_(row, {"a": 1, "b": 1, "c": 1})
+
     def test_one(self):
         result = self._fixture(num_rows=1)
 
index 1831ee9e4ad48e9e08383a848eaae4737379a032..24c4294e348e9be39c1b90146286fa5cf73b671b 100644 (file)
@@ -533,6 +533,30 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings 3.8_p
 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings 3.8_sqlite_pysqlite_dbapiunicode_cextensions 2469
 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 15476
 
+# TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0]
+
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 14
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] 3.8_sqlite_pysqlite_dbapiunicode_cextensions 15
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 15
+
+# TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1]
+
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 17
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] 3.8_sqlite_pysqlite_dbapiunicode_cextensions 16
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 18
+
+# TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2]
+
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 17
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] 3.8_sqlite_pysqlite_dbapiunicode_cextensions 16
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 18
+
+# TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1]
+
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 20
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] 3.8_sqlite_pysqlite_dbapiunicode_cextensions 19
+test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] 3.8_sqlite_pysqlite_dbapiunicode_nocextensions 21
+
 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string
 
 test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string 2.7_mssql_pyodbc_dbapiunicode_cextensions 276