]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: expose optimized functions directly from the origin modules
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 15 Aug 2022 00:49:24 +0000 (02:49 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 15 Aug 2022 13:38:13 +0000 (15:38 +0200)
Don't check for the availability of the optimized module in every module
using potentially optimized functions.

psycopg/psycopg/connection.py
psycopg/psycopg/cursor.py
psycopg/psycopg/generators.py
psycopg/psycopg/server_cursor.py
tests/test_module.py

index 4526b22faa6e8a302e5a1dbc1f5ac72d0f7dbe3a..9c1580dc2674a9630f537fb5591b5957867e3f72 100644 (file)
@@ -27,9 +27,8 @@ from .rows import Row, RowFactory, tuple_row, TupleRow
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._cmodule import _psycopg
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
-from .generators import notifies
+from .generators import connect, execute, notifies
 from ._encodings import pgconn_encoding
 from ._preparing import PrepareManager
 from .transaction import Transaction
@@ -41,23 +40,10 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger("psycopg")
 
-connect: Callable[[str], PQGenConn["PGconn"]]
-execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
-
 # Row Type variable for Cursor (when it needs to be distinguished from the
 # connection's one)
 CursorRow = TypeVar("CursorRow")
 
-if _psycopg:
-    connect = _psycopg.connect
-    execute = _psycopg.execute
-
-else:
-    from . import generators
-
-    connect = generators.connect
-    execute = generators.execute
-
 
 class Notify(NamedTuple):
     """An asynchronous notification received from the database."""
index ebd1a8e6d2dbd87b19ad24c9df8f3f35d5579b1c..987d2f5c053ca71932e9fb72998f65927a434dc3 100644 (file)
@@ -19,28 +19,16 @@ from .abc import ConnectionType, Query, Params, PQGen
 from .copy import Copy
 from .rows import Row, RowMaker, RowFactory
 from ._column import Column
-from ._cmodule import _psycopg
 from ._queries import PostgresQuery
 from ._encodings import pgconn_encoding
 from ._preparing import Prepare
+from .generators import execute, fetch, send
 
 if TYPE_CHECKING:
     from .abc import Transformer
     from .pq.abc import PGconn, PGresult
     from .connection import Connection
 
-if _psycopg:
-    execute = _psycopg.execute
-    fetch = _psycopg.fetch
-    send = _psycopg.send
-
-else:
-    from . import generators
-
-    execute = generators.execute
-    fetch = generators.fetch
-    send = generators.send
-
 _C = TypeVar("_C", bound="Cursor[Any]")
 
 ACTIVE = pq.TransactionStatus.ACTIVE
index b34a6aecb1fd33afdea3c954bcc2d29b3c668e16..fbd5095f116e7a19e7bc7a4ac970363fb95c037b 100644 (file)
@@ -24,12 +24,13 @@ from .pq import ConnStatus, PollingStatus, ExecStatus
 from .abc import PQGen, PQGenConn
 from .pq.abc import PGconn, PGresult
 from .waiting import Wait, Ready
+from ._cmodule import _psycopg
 from ._encodings import pgconn_encoding, conninfo_encoding
 
 logger = logging.getLogger(__name__)
 
 
-def connect(conninfo: str) -> PQGenConn[PGconn]:
+def _connect(conninfo: str) -> PQGenConn[PGconn]:
     """
     Generator to create a database connection without blocking.
 
@@ -61,7 +62,7 @@ def connect(conninfo: str) -> PQGenConn[PGconn]:
     return conn
 
 
-def execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
+def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     Generator sending a query and returning results without blocking.
 
@@ -72,12 +73,12 @@ def execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
     Return the list of results returned by the database (whether success
     or error).
     """
-    yield from send(pgconn)
-    rv = yield from fetch_many(pgconn)
+    yield from _send(pgconn)
+    rv = yield from _fetch_many(pgconn)
     return rv
 
 
-def send(pgconn: PGconn) -> PQGen[None]:
+def _send(pgconn: PGconn) -> PQGen[None]:
     """
     Generator to send a query to the server without blocking.
 
@@ -100,7 +101,7 @@ def send(pgconn: PGconn) -> PQGen[None]:
             pgconn.consume_input()
 
 
-def fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
+def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     Generator retrieving results from the database without blocking.
 
@@ -112,7 +113,7 @@ def fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     results: List[PGresult] = []
     while 1:
-        res = yield from fetch(pgconn)
+        res = yield from _fetch(pgconn)
         if not res:
             break
 
@@ -125,7 +126,7 @@ def fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
     return results
 
 
-def fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
+def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
     """
     Generator retrieving a single result from the database without blocking.
 
@@ -190,7 +191,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
         return data
 
     # Retrieve the final result of copy
-    results = yield from fetch_many(pgconn)
+    results = yield from _fetch_many(pgconn)
     if len(results) > 1:
         # TODO: too brutal? Copy worked.
         raise e.ProgrammingError("you cannot mix COPY with other operations")
@@ -226,9 +227,25 @@ def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
             break
 
     # Retrieve the final result of copy
-    (result,) = yield from fetch_many(pgconn)
+    (result,) = yield from _fetch_many(pgconn)
     if result.status != ExecStatus.COMMAND_OK:
         encoding = pgconn_encoding(pgconn)
         raise e.error_from_result(result, encoding=encoding)
 
     return result
+
+
+# Override functions with fast versions if available
+if _psycopg:
+    connect = _psycopg.connect
+    execute = _psycopg.execute
+    send = _psycopg.send
+    fetch_many = _psycopg.fetch_many
+    fetch = _psycopg.fetch
+
+else:
+    connect = _connect
+    execute = _execute
+    send = _send
+    fetch_many = _fetch_many
+    fetch = _fetch
index 6bf669980b97d02645a906e51f91be79bf1ddddb..bf57069cd4bd207569a007356faa84f82a9cb055 100644 (file)
@@ -13,7 +13,8 @@ from . import sql
 from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowFactory, AsyncRowFactory
-from .cursor import BaseCursor, Cursor, execute
+from .cursor import BaseCursor, Cursor
+from .generators import execute
 from .cursor_async import AsyncCursor
 
 if TYPE_CHECKING:
index 2df4d4705837320997cf0c0a2d8ee9ac6ee03ca7..794ef0f89ec6ca2db34de070ff23a6561b7a314b 100644 (file)
@@ -17,7 +17,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
     # Details of the params manipulation are in test_conninfo.
     import psycopg.connection
 
-    orig_connect = psycopg.connection.connect
+    orig_connect = psycopg.connection.connect  # type: ignore
 
     got_conninfo = None