]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve implementation of server side cursor in asyncpg
authorFederico Caselli <cfederico87@gmail.com>
Sat, 4 May 2024 20:12:28 +0000 (22:12 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 6 May 2024 17:41:26 +0000 (19:41 +0200)
Change-Id: I36d0ff5ccea7fbf46cabcfeae1492b9a90b7f68b
(cherry picked from commit e3571e1d4b4d34a250886a8967a9b1339f0c68a7)

lib/sqlalchemy/dialects/postgresql/asyncpg.py

index df2656de2a8cad89d58811dc7e9fb86bde34f488..12e711f52e23685c5cbfd5e88aee9615836aa39a 100644 (file)
@@ -171,7 +171,7 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
 
 from __future__ import annotations
 
-import collections
+from collections import deque
 import decimal
 import json as _py_json
 import re
@@ -611,23 +611,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
 
     def __init__(self, adapt_connection):
         super().__init__(adapt_connection)
-        self._rowbuffer = None
+        self._rowbuffer = deque()
 
     def close(self):
         self._cursor = None
-        self._rowbuffer = None
+        self._rowbuffer.clear()
 
     def _buffer_rows(self):
+        assert self._cursor is not None
         new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
-        self._rowbuffer = collections.deque(new_rows)
+        self._rowbuffer.extend(new_rows)
 
     def __aiter__(self):
         return self
 
     async def __anext__(self):
-        if not self._rowbuffer:
-            self._buffer_rows()
-
         while True:
             while self._rowbuffer:
                 yield self._rowbuffer.popleft()
@@ -650,21 +648,19 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         if not self._rowbuffer:
             self._buffer_rows()
 
-        buf = list(self._rowbuffer)
-        lb = len(buf)
+        assert self._cursor is not None
+        rb = self._rowbuffer
+        lb = len(rb)
         if size > lb:
-            buf.extend(
+            rb.extend(
                 self._adapt_connection.await_(self._cursor.fetch(size - lb))
             )
 
-        result = buf[0:size]
-        self._rowbuffer = collections.deque(buf[size:])
-        return result
+        return [rb.popleft() for _ in range(min(size, len(rb)))]
 
     def fetchall(self):
-        ret = list(self._rowbuffer) + list(
-            self._adapt_connection.await_(self._all())
-        )
+        ret = list(self._rowbuffer)
+        ret.extend(self._adapt_connection.await_(self._all()))
         self._rowbuffer.clear()
         return ret