From 401533e218c5e7950d65489863e0207d4787e50d Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 4 May 2024 22:12:28 +0200 Subject: [PATCH] Improve implementation of server side cursor in asyncpg Change-Id: I36d0ff5ccea7fbf46cabcfeae1492b9a90b7f68b (cherry picked from commit e3571e1d4b4d34a250886a8967a9b1339f0c68a7) --- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index df2656de2a..12e711f52e 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -171,7 +171,7 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: from __future__ import annotations -import collections +from collections import deque import decimal import json as _py_json import re @@ -611,23 +611,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): def __init__(self, adapt_connection): super().__init__(adapt_connection) - self._rowbuffer = None + self._rowbuffer = deque() def close(self): self._cursor = None - self._rowbuffer = None + self._rowbuffer.clear() def _buffer_rows(self): + assert self._cursor is not None new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) - self._rowbuffer = collections.deque(new_rows) + self._rowbuffer.extend(new_rows) def __aiter__(self): return self async def __anext__(self): - if not self._rowbuffer: - self._buffer_rows() - while True: while self._rowbuffer: yield self._rowbuffer.popleft() @@ -650,21 +648,19 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): if not self._rowbuffer: self._buffer_rows() - buf = list(self._rowbuffer) - lb = len(buf) + assert self._cursor is not None + rb = self._rowbuffer + lb = len(rb) if size > lb: - buf.extend( + rb.extend( self._adapt_connection.await_(self._cursor.fetch(size - lb)) ) - result = buf[0:size] - self._rowbuffer = collections.deque(buf[size:]) - return result + return [rb.popleft() for _ in range(min(size, len(rb)))] def fetchall(self): - ret = list(self._rowbuffer) + list( - self._adapt_connection.await_(self._all()) - ) + ret = list(self._rowbuffer) + ret.extend(self._adapt_connection.await_(self._all())) self._rowbuffer.clear() return ret -- 2.47.2