]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Import pq enums in modules where they are used
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Nov 2020 01:18:57 +0000 (01:18 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Nov 2020 01:18:57 +0000 (01:18 +0000)
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/types/composite.py

index 713fe222758dffdc8ed922cd54bfac56fa3686d2..360a945a420dd8c54412d02f1af40540a053f279 100644 (file)
@@ -8,7 +8,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 from typing import TYPE_CHECKING
 
 from . import errors as e
-from . import pq
+from .pq import Format
 from .oids import builtins, INVALID_OID
 from .proto import AdaptContext, DumpersMap
 from .proto import LoadFunc, LoadersMap
@@ -16,9 +16,9 @@ from .cursor import BaseCursor
 from .connection import BaseConnection
 
 if TYPE_CHECKING:
+    from .pq.proto import PGresult
     from .adapt import Dumper, Loader
 
-Format = pq.Format
 TEXT_OID = builtins["text"].oid
 
 
@@ -37,7 +37,7 @@ class Transformer:
         self._dumpers_maps: List[DumpersMap] = []
         self._loaders_maps: List[LoadersMap] = []
         self._setup_context(context)
-        self.pgresult = None
+        self._pgresult: Optional["PGresult"] = None
 
         # mapping class, fmt -> Dumper instance
         self._dumpers_cache: Dict[Tuple[type, Format], "Dumper"] = {}
@@ -107,11 +107,11 @@ class Transformer:
         return self._encoding
 
     @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
+    def pgresult(self) -> Optional["PGresult"]:
         return self._pgresult
 
     @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+    def pgresult(self, result: Optional["PGresult"]) -> None:
         self._pgresult = result
         rc = self._row_loaders = []
 
@@ -184,7 +184,7 @@ class Transformer:
         )
 
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
-        res = self.pgresult
+        res = self._pgresult
         if not res:
             return None
 
index 2d822fb731706ea38041506762f5b4c5d841f37d..48097661f2596c4419aab6c80055d1ee09c070d8 100644 (file)
@@ -24,7 +24,7 @@ from . import pq
 from . import cursor
 from . import errors as e
 from . import encodings
-from .pq import TransactionStatus, ExecStatus
+from .pq import TransactionStatus, ExecStatus, Format
 from .sql import Composable
 from .proto import DumpersMap, LoadersMap, PQGen, RV, Query
 from .waiting import wait, wait_async
@@ -274,7 +274,7 @@ class Connection(BaseConnection):
         self.pgconn.finish()
 
     def cursor(
-        self, name: str = "", format: pq.Format = pq.Format.TEXT
+        self, name: str = "", format: Format = Format.TEXT
     ) -> "psycopg3.Cursor":
         """
         Return a new `Cursor` to send commands and queries to the connection.
@@ -432,7 +432,7 @@ class AsyncConnection(BaseConnection):
         self.pgconn.finish()
 
     async def cursor(
-        self, name: str = "", format: pq.Format = pq.Format.TEXT
+        self, name: str = "", format: Format = Format.TEXT
     ) -> "psycopg3.AsyncCursor":
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
index 3a713edd4edbe84a3f8a4c5a51e8ae8b4be07d61..7bd54d20d36cd8cbd86e74429942cad983e26e44 100644 (file)
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 
 from . import errors as e
 from . import pq
+from .pq import ConnStatus, ExecStatus, Format
 from .oids import builtins
 from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
@@ -166,7 +167,7 @@ class BaseCursor(Generic[ConnectionType]):
     def __init__(
         self,
         connection: ConnectionType,
-        format: pq.Format = pq.Format.TEXT,
+        format: Format = Format.TEXT,
     ):
         self._conn = connection
         self.format = format
@@ -225,12 +226,12 @@ class BaseCursor(Generic[ConnectionType]):
     @property
     def description(self) -> Optional[List[Column]]:
         """
-        A list of `Column` object describing the current resultset.
+        A list of `Column` objects describing the current resultset.
 
         `!None` if the current resultset didn't return tuples.
         """
         res = self.pgresult
-        if not res or res.status != self.ExecStatus.TUPLES_OK:
+        if not res or res.status != ExecStatus.TUPLES_OK:
             return None
         encoding = self._conn.client_encoding
         return [Column(res, i, encoding) for i in range(res.nfields)]
@@ -274,7 +275,7 @@ class BaseCursor(Generic[ConnectionType]):
         if self._conn.closed:
             raise e.InterfaceError("the connection is closed")
 
-        if self._conn.pgconn.status != pq.ConnStatus.OK:
+        if self._conn.pgconn.status != ConnStatus.OK:
             raise e.InterfaceError(
                 f"cannot execute operations: the connection is"
                 f" in status {self._conn.pgconn.status}"
@@ -292,7 +293,7 @@ class BaseCursor(Generic[ConnectionType]):
         pgq = PostgresQuery(self._transformer)
         pgq.convert(query, params)
 
-        if pgq.params or no_pqexec or self.format == pq.Format.BINARY:
+        if pgq.params or no_pqexec or self.format == Format.BINARY:
             self._query = pgq.query
             self._params = pgq.params
             self._conn.pgconn.send_query_params(
@@ -309,6 +310,17 @@ class BaseCursor(Generic[ConnectionType]):
             self._params = None
             self._conn.pgconn.send_query(pgq.query)
 
+    _status_ok = {
+        ExecStatus.TUPLES_OK,
+        ExecStatus.COMMAND_OK,
+        ExecStatus.EMPTY_QUERY,
+    }
+    _status_copy = {
+        ExecStatus.COPY_IN,
+        ExecStatus.COPY_OUT,
+        ExecStatus.COPY_BOTH,
+    }
+
     def _execute_results(self, results: Sequence["PGresult"]) -> None:
         """
         Implement part of execute() after waiting common to sync and async
@@ -316,9 +328,8 @@ class BaseCursor(Generic[ConnectionType]):
         if not results:
             raise e.InternalError("got no result from the query")
 
-        S = self.ExecStatus
         statuses = {res.status for res in results}
-        badstats = statuses - {S.TUPLES_OK, S.COMMAND_OK, S.EMPTY_QUERY}
+        badstats = statuses - self._status_ok
         if not badstats:
             self._results = list(results)
             self.pgresult = results[0]
@@ -331,12 +342,11 @@ class BaseCursor(Generic[ConnectionType]):
 
             return
 
-        if results[-1].status == S.FATAL_ERROR:
+        if results[-1].status == ExecStatus.FATAL_ERROR:
             raise e.error_from_result(
                 results[-1], encoding=self._conn.client_encoding
             )
-
-        elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}:
+        elif badstats & self._status_copy:
             raise e.ProgrammingError(
                 "COPY cannot be used with execute(); use copy() insead"
             )
@@ -373,7 +383,7 @@ class BaseCursor(Generic[ConnectionType]):
         res = self.pgresult
         if not res:
             raise e.ProgrammingError("no result available")
-        elif res.status != self.ExecStatus.TUPLES_OK:
+        elif res.status != ExecStatus.TUPLES_OK:
             raise e.ProgrammingError(
                 "the last operation didn't produce a result"
             )
@@ -389,16 +399,16 @@ class BaseCursor(Generic[ConnectionType]):
 
         result = results[0]
         status = result.status
-        if status in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
+        if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT):
             return
-        elif status == pq.ExecStatus.FATAL_ERROR:
+        elif status == ExecStatus.FATAL_ERROR:
             raise e.error_from_result(
                 result, encoding=self._conn.client_encoding
             )
         else:
             raise e.ProgrammingError(
                 "copy() should be used only with COPY ... TO STDOUT or COPY ..."
-                f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
+                f" FROM STDIN statements, got {ExecStatus(status).name}"
             )
 
 
@@ -449,7 +459,7 @@ class Cursor(BaseCursor["Connection"]):
                     pgq = self._send_prepare(b"", query, params)
                     gen = execute(self._conn.pgconn)
                     (result,) = self._conn.wait(gen)
-                    if result.status == self.ExecStatus.FATAL_ERROR:
+                    if result.status == ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
                             result, encoding=self._conn.client_encoding
                         )
@@ -578,7 +588,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
                     pgq = self._send_prepare(b"", query, params)
                     gen = execute(self._conn.pgconn)
                     (result,) = await self._conn.wait(gen)
-                    if result.status == self.ExecStatus.FATAL_ERROR:
+                    if result.status == ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
                             result, encoding=self._conn.client_encoding
                         )
index 96944206d282a6d55535b51fecbda05bb5be4bad..933a9e21e8a8bdcdfdf3e3ac925eafb36db477bc 100644 (file)
@@ -20,6 +20,7 @@ from typing import List, Optional, Union
 
 from . import pq
 from . import errors as e
+from .pq import ConnStatus, PollingStatus, ExecStatus
 from .proto import PQGen
 from .waiting import Wait, Ready
 from .encodings import py_codecs
@@ -36,20 +37,20 @@ def connect(conninfo: str) -> PQGen[PGconn]:
     conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
     logger.debug("connection started, status %s", conn.status.name)
     while 1:
-        if conn.status == pq.ConnStatus.BAD:
+        if conn.status == ConnStatus.BAD:
             raise e.OperationalError(
                 f"connection is bad: {pq.error_message(conn)}"
             )
 
         status = conn.connect_poll()
         logger.debug("connection polled, status %s", conn.status.name)
-        if status == pq.PollingStatus.OK:
+        if status == PollingStatus.OK:
             break
-        elif status == pq.PollingStatus.READING:
+        elif status == PollingStatus.READING:
             yield conn.socket, Wait.R
-        elif status == pq.PollingStatus.WRITING:
+        elif status == PollingStatus.WRITING:
             yield conn.socket, Wait.W
-        elif status == pq.PollingStatus.FAILED:
+        elif status == PollingStatus.FAILED:
             raise e.OperationalError(
                 f"connection failed: {pq.error_message(conn)}"
             )
@@ -110,7 +111,6 @@ def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
     Return the list of results returned by the database (whether success
     or error).
     """
-    S = pq.ExecStatus
     results: List[PGresult] = []
     while 1:
         pgconn.consume_input()
@@ -130,7 +130,7 @@ def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
         if res is None:
             break
         results.append(res)
-        if res.status in (S.COPY_IN, S.COPY_OUT, S.COPY_BOTH):
+        if res.status in _copy_statuses:
             # After entering copy mode the libpq will create a phony result
             # for every request so let's break the endless loop.
             break
@@ -138,6 +138,13 @@ def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
     return results
 
 
+_copy_statuses = (
+    ExecStatus.COPY_IN,
+    ExecStatus.COPY_OUT,
+    ExecStatus.COPY_BOTH,
+)
+
+
 def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
     yield pgconn.socket, Wait.R
     pgconn.consume_input()
@@ -169,7 +176,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
 
     # Retrieve the final result of copy
     (result,) = yield from fetch(pgconn)
-    if result.status != pq.ExecStatus.COMMAND_OK:
+    if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
         )
@@ -198,7 +205,7 @@ def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
 
     # Retrieve the final result of copy
     (result,) = yield from fetch(pgconn)
-    if result.status != pq.ExecStatus.COMMAND_OK:
+    if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
         )
index f27505db9683097932b1cacdde8014732d436368..320fa669ede30955aeb3041dfda7976567c926e3 100644 (file)
@@ -8,7 +8,6 @@ from collections import namedtuple
 from typing import Any, Callable, Iterator, Sequence, Tuple, Type
 from typing import Optional, TYPE_CHECKING
 
-from .. import pq
 from ..oids import builtins, TypeInfo
 from ..adapt import Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
@@ -45,7 +44,7 @@ class CompositeTypeInfo(TypeInfo):
 
 
 def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
-    cur = conn.cursor(format=pq.Format.BINARY)
+    cur = conn.cursor(format=Format.BINARY)
     cur.execute(_type_info_query, {"name": name})
     rec = cur.fetchone()
     return CompositeTypeInfo._from_record(rec)
@@ -54,7 +53,7 @@ def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
 async def fetch_info_async(
     conn: "AsyncConnection", name: str
 ) -> Optional[CompositeTypeInfo]:
-    cur = await conn.cursor(format=pq.Format.BINARY)
+    cur = await conn.cursor(format=Format.BINARY)
     await cur.execute(_type_info_query, {"name": name})
     rec = await cur.fetchone()
     return CompositeTypeInfo._from_record(rec)