]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve implementation of server side cursor in asyncpg
authorFederico Caselli <cfederico87@gmail.com>
Sat, 4 May 2024 20:12:28 +0000 (22:12 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 6 May 2024 17:40:04 +0000 (19:40 +0200)
Change-Id: I36d0ff5ccea7fbf46cabcfeae1492b9a90b7f68b

lib/sqlalchemy/dialects/postgresql/asyncpg.py

index c9a39eb3eb3ce5ea7208237a2bfbeb6eb0fef9d9..66cdeb84639ea044675582ce6c4652a0a851457d 100644 (file)
@@ -172,7 +172,7 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
 from __future__ import annotations
 
 import asyncio
-import collections
+from collections import deque
 import decimal
 import json as _py_json
 import re
@@ -530,7 +530,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
         self._adapt_connection = adapt_connection
         self._connection = adapt_connection._connection
         self._cursor = None
-        self._rows = collections.deque()
+        self._rows = deque()
         self._description = None
         self._arraysize = 1
         self._rowcount = -1
@@ -574,9 +574,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
                     self._cursor = await prepared_stmt.cursor(*parameters)
                     self._rowcount = -1
                 else:
-                    self._rows = collections.deque(
-                        await prepared_stmt.fetch(*parameters)
-                    )
+                    self._rows = deque(await prepared_stmt.fetch(*parameters))
                     status = prepared_stmt.get_statusmsg()
 
                     reg = re.match(
@@ -643,24 +641,21 @@ class AsyncAdapt_asyncpg_ss_cursor(
 
     def __init__(self, adapt_connection):
         super().__init__(adapt_connection)
-        self._rowbuffer = None
+        self._rowbuffer = deque()
 
     def close(self):
         self._cursor = None
-        self._rowbuffer = None
+        self._rowbuffer.clear()
 
     def _buffer_rows(self):
         assert self._cursor is not None
         new_rows = await_(self._cursor.fetch(50))
-        self._rowbuffer = collections.deque(new_rows)
+        self._rowbuffer.extend(new_rows)
 
     def __aiter__(self):
         return self
 
     async def __anext__(self):
-        if not self._rowbuffer:
-            self._buffer_rows()
-
         while True:
             while self._rowbuffer:
                 yield self._rowbuffer.popleft()
@@ -683,22 +678,17 @@ class AsyncAdapt_asyncpg_ss_cursor(
         if not self._rowbuffer:
             self._buffer_rows()
 
-        assert self._rowbuffer is not None
         assert self._cursor is not None
-
-        buf = list(self._rowbuffer)
-        lb = len(buf)
+        rb = self._rowbuffer
+        lb = len(rb)
         if size > lb:
-            buf.extend(await_(self._cursor.fetch(size - lb)))
+            rb.extend(await_(self._cursor.fetch(size - lb)))
 
-        result = buf[0:size]
-        self._rowbuffer = collections.deque(buf[size:])
-        return result
+        return [rb.popleft() for _ in range(min(size, len(rb)))]
 
     def fetchall(self):
-        assert self._rowbuffer is not None
-
-        ret = list(self._rowbuffer) + list(await_(self._all()))
+        ret = list(self._rowbuffer)
+        ret.extend(await_(self._all()))
         self._rowbuffer.clear()
         return ret