]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cursor.description can be pickled
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 1 Dec 2020 02:54:56 +0000 (02:54 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 1 Dec 2020 02:54:56 +0000 (02:54 +0000)
psycopg3/psycopg3/__init__.py
psycopg3/psycopg3/_column.py [new file with mode: 0644]
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
tests/test_cursor.py

index 05674fbcd8e532a0fe112e5dc18523ffef199587..2aa334a77bb0bca34faef9b0ca40308c22cbf37d 100644 (file)
@@ -6,10 +6,11 @@ psycopg3 -- PostgreSQL database adapter for Python
 
 from . import pq
 from .copy import Copy, AsyncCopy
-from .cursor import AsyncCursor, Cursor, Column
+from .cursor import AsyncCursor, Cursor
 from .errors import Warning, Error, InterfaceError, DatabaseError
 from .errors import DataError, OperationalError, IntegrityError
 from .errors import InternalError, ProgrammingError, NotSupportedError
+from ._column import Column
 from .connection import AsyncConnection, Connection, Notify
 from .transaction import Rollback, Transaction, AsyncTransaction
 
diff --git a/psycopg3/psycopg3/_column.py b/psycopg3/psycopg3/_column.py
new file mode 100644 (file)
index 0000000..b9ae001
--- /dev/null
@@ -0,0 +1,138 @@
+from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
+from operator import attrgetter
+
+from . import errors as e
+from .oids import builtins
+
+if TYPE_CHECKING:
+    from .cursor import BaseCursor
+
+
+class ColumnData(NamedTuple):
+    ftype: int
+    fmod: int
+    fsize: int
+
+
+class Column(Sequence[Any]):
+
+    __module__ = "psycopg3"
+
+    def __init__(self, cursor: "BaseCursor[Any]", index: int):
+        res = cursor.pgresult
+        assert res
+
+        fname = res.fname(index)
+        if not fname:
+            raise e.InterfaceError(f"no name available for column {index}")
+
+        self._name = fname.decode(cursor.connection.client_encoding)
+
+        self._data = ColumnData(
+            ftype=res.ftype(index),
+            fmod=res.fmod(index),
+            fsize=res.fsize(index),
+        )
+
+    _attrs = tuple(
+        attrgetter(attr)
+        for attr in """
+            name type_code display_size internal_size precision scale null_ok
+            """.split()
+    )
+
+    def __repr__(self) -> str:
+        return f"<Column {self.name}, type: {self._type_display()}>"
+
+    def __len__(self) -> int:
+        return 7
+
+    def _type_display(self) -> str:
+        parts = []
+        t = builtins.get(self.type_code)
+        parts.append(t.name if t else str(self.type_code))
+
+        mod1 = self.precision
+        if mod1 is None:
+            mod1 = self.display_size
+        if mod1:
+            parts.append(f"({mod1}")
+            if self.scale:
+                parts.append(f", {self.scale}")
+            parts.append(")")
+
+        return "".join(parts)
+
+    def __getitem__(self, index: Any) -> Any:
+        if isinstance(index, slice):
+            return tuple(getter(self) for getter in self._attrs[index])
+        else:
+            return self._attrs[index](self)
+
+    @property
+    def name(self) -> str:
+        """The name of the column."""
+        return self._name
+
+    @property
+    def type_code(self) -> int:
+        """The numeric OID of the column."""
+        return self._data.ftype
+
+    @property
+    def display_size(self) -> Optional[int]:
+        """The field size, for :sql:`varchar(n)`, None otherwise."""
+        t = builtins.get(self.type_code)
+        if not t:
+            return None
+
+        if t.name in ("varchar", "char"):
+            fmod = self._data.fmod
+            if fmod >= 0:
+                return fmod - 4
+
+        return None
+
+    @property
+    def internal_size(self) -> Optional[int]:
+        """The interal field size for fixed-size types, None otherwise."""
+        fsize = self._data.fsize
+        return fsize if fsize >= 0 else None
+
+    @property
+    def precision(self) -> Optional[int]:
+        """The number of digits for fixed precision types."""
+        t = builtins.get(self.type_code)
+        if not t:
+            return None
+
+        dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
+        if t.name == "numeric":
+            fmod = self._data.fmod
+            if fmod >= 0:
+                return fmod >> 16
+
+        elif t.name in dttypes:
+            fmod = self._data.fmod
+            if fmod >= 0:
+                return fmod & 0xFFFF
+
+        return None
+
+    @property
+    def scale(self) -> Optional[int]:
+        """The number of digits after the decimal point if available.
+
+        TODO: probably better than precision for datetime objects? review.
+        """
+        if self.type_code == builtins["numeric"].oid:
+            fmod = self._data.fmod - 4
+            if fmod >= 0:
+                return fmod & 0xFFFF
+
+        return None
+
+    @property
+    def null_ok(self) -> Optional[bool]:
+        """Always `!None`"""
+        return None
index 8813dcb3404dcb0bbf0bacad64e61539f040a5b5..f370f0be1733b7708cb5835c54b8a9c5afb8b9f9 100644 (file)
@@ -10,7 +10,7 @@ import logging
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING, Union
+from typing import Optional, Type, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -39,7 +39,7 @@ connect: Callable[[str], PQGen["PGconn"]]
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if TYPE_CHECKING:
-    from .cursor import AsyncCursor, Cursor
+    from .cursor import AsyncCursor, BaseCursor, Cursor
     from .pq.proto import PGconn, PGresult
 
 if pq.__impl__ == "c":
@@ -98,7 +98,7 @@ class BaseConnection:
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    cursor_factory: Union[Type["Cursor"], Type["AsyncCursor"]]
+    cursor_factory: Type["BaseCursor[Any]"]
 
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
index 7c424d5c8f8c4ec7e6da5796670ab83e2081d9d5..1f4b854b29d189086f8e9a2b6a7a63018f4d0c22 100644 (file)
@@ -8,15 +8,14 @@ import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
 from typing import Optional, Sequence, Type, TYPE_CHECKING
-from operator import attrgetter
 from contextlib import contextmanager
 
 from . import errors as e
 from . import pq
 from .pq import ConnStatus, ExecStatus, Format
-from .oids import builtins
 from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
+from ._column import Column
 from ._queries import PostgresQuery
 
 if sys.version_info >= (3, 7):
@@ -42,125 +41,6 @@ else:
     execute = generators.execute
 
 
-class Column(Sequence[Any]):
-
-    __module__ = "psycopg3"
-
-    def __init__(self, pgresult: "PGresult", index: int, encoding: str):
-        self._pgresult = pgresult
-        self._index = index
-        self._encoding = encoding
-
-    _attrs = tuple(
-        attrgetter(attr)
-        for attr in """
-            name type_code display_size internal_size precision scale null_ok
-            """.split()
-    )
-
-    def __repr__(self) -> str:
-        return f"<Column {self.name}, type: {self._type_display()}>"
-
-    def __len__(self) -> int:
-        return 7
-
-    def _type_display(self) -> str:
-        parts = []
-        t = builtins.get(self.type_code)
-        parts.append(t.name if t else str(self.type_code))
-
-        mod1 = self.precision
-        if mod1 is None:
-            mod1 = self.display_size
-        if mod1:
-            parts.append(f"({mod1}")
-            if self.scale:
-                parts.append(f", {self.scale}")
-            parts.append(")")
-
-        return "".join(parts)
-
-    def __getitem__(self, index: Any) -> Any:
-        if isinstance(index, slice):
-            return tuple(getter(self) for getter in self._attrs[index])
-        else:
-            return self._attrs[index](self)
-
-    @property
-    def name(self) -> str:
-        """The name of the column."""
-        rv = self._pgresult.fname(self._index)
-        if rv:
-            return rv.decode(self._encoding)
-        else:
-            raise e.InterfaceError(
-                f"no name available for column {self._index}"
-            )
-
-    @property
-    def type_code(self) -> int:
-        """The numeric OID of the column."""
-        return self._pgresult.ftype(self._index)
-
-    @property
-    def display_size(self) -> Optional[int]:
-        """The field size, for :sql:`varchar(n)`, None otherwise."""
-        t = builtins.get(self.type_code)
-        if not t:
-            return None
-
-        if t.name in ("varchar", "char"):
-            fmod = self._pgresult.fmod(self._index)
-            if fmod >= 0:
-                return fmod - 4
-
-        return None
-
-    @property
-    def internal_size(self) -> Optional[int]:
-        """The interal field size for fixed-size types, None otherwise."""
-        fsize = self._pgresult.fsize(self._index)
-        return fsize if fsize >= 0 else None
-
-    @property
-    def precision(self) -> Optional[int]:
-        """The number of digits for fixed precision types."""
-        t = builtins.get(self.type_code)
-        if not t:
-            return None
-
-        dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval")
-        if t.name == "numeric":
-            fmod = self._pgresult.fmod(self._index)
-            if fmod >= 0:
-                return fmod >> 16
-
-        elif t.name in dttypes:
-            fmod = self._pgresult.fmod(self._index)
-            if fmod >= 0:
-                return fmod & 0xFFFF
-
-        return None
-
-    @property
-    def scale(self) -> Optional[int]:
-        """The number of digits after the decimal point if available.
-
-        TODO: probably better than precision for datetime objects? review.
-        """
-        if self.type_code == builtins["numeric"].oid:
-            fmod = self._pgresult.fmod(self._index) - 4
-            if fmod >= 0:
-                return fmod & 0xFFFF
-
-        return None
-
-    @property
-    def null_ok(self) -> Optional[bool]:
-        """Always `!None`"""
-        return None
-
-
 class BaseCursor(Generic[ConnectionType]):
     ExecStatus = pq.ExecStatus
 
@@ -236,8 +116,7 @@ class BaseCursor(Generic[ConnectionType]):
         res = self.pgresult
         if not res or res.status != ExecStatus.TUPLES_OK:
             return None
-        encoding = self._conn.client_encoding
-        return [Column(res, i, encoding) for i in range(res.nfields)]
+        return [Column(self, i) for i in range(res.nfields)]
 
     @property
     def rowcount(self) -> int:
index 36ca400512d3800de85943565fa525ef36b10349..45750143e69061db454caadd806ee469dfa56e8f 100644 (file)
@@ -1,7 +1,9 @@
 import gc
-import pytest
+import pickle
 import weakref
 
+import pytest
+
 import psycopg3
 from psycopg3.oids import builtins
 
@@ -343,3 +345,17 @@ class TestColumn:
         assert col.scale == scale
         assert col.display_size == dsize
         assert col.internal_size == isize
+
+    def test_pickle(self, conn):
+        curs = conn.cursor()
+        curs.execute(
+            """select
+            3.14::decimal(10,2) as pi,
+            'hello'::text as hi,
+            '2010-02-18'::date as now
+            """
+        )
+        description = curs.description
+        pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
+        unpickled = pickle.loads(pickled)
+        assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]