From: Daniele Varrazzo Date: Tue, 1 Dec 2020 02:54:56 +0000 (+0000) Subject: Cursor.description can be pickled X-Git-Tag: 3.0.dev0~296 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=18b799663afcc0758790a60dec4251103b3c2c64;p=thirdparty%2Fpsycopg.git Cursor.description can be pickled --- diff --git a/psycopg3/psycopg3/__init__.py b/psycopg3/psycopg3/__init__.py index 05674fbcd..2aa334a77 100644 --- a/psycopg3/psycopg3/__init__.py +++ b/psycopg3/psycopg3/__init__.py @@ -6,10 +6,11 @@ psycopg3 -- PostgreSQL database adapter for Python from . import pq from .copy import Copy, AsyncCopy -from .cursor import AsyncCursor, Cursor, Column +from .cursor import AsyncCursor, Cursor from .errors import Warning, Error, InterfaceError, DatabaseError from .errors import DataError, OperationalError, IntegrityError from .errors import InternalError, ProgrammingError, NotSupportedError +from ._column import Column from .connection import AsyncConnection, Connection, Notify from .transaction import Rollback, Transaction, AsyncTransaction diff --git a/psycopg3/psycopg3/_column.py b/psycopg3/psycopg3/_column.py new file mode 100644 index 000000000..b9ae0013b --- /dev/null +++ b/psycopg3/psycopg3/_column.py @@ -0,0 +1,138 @@ +from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING +from operator import attrgetter + +from . import errors as e +from .oids import builtins + +if TYPE_CHECKING: + from .cursor import BaseCursor + + +class ColumnData(NamedTuple): + ftype: int + fmod: int + fsize: int + + +class Column(Sequence[Any]): + + __module__ = "psycopg3" + + def __init__(self, cursor: "BaseCursor[Any]", index: int): + res = cursor.pgresult + assert res + + fname = res.fname(index) + if not fname: + raise e.InterfaceError(f"no name available for column {index}") + + self._name = fname.decode(cursor.connection.client_encoding) + + self._data = ColumnData( + ftype=res.ftype(index), + fmod=res.fmod(index), + fsize=res.fsize(index), + ) + + _attrs = tuple( + attrgetter(attr) + for attr in """ + name type_code display_size internal_size precision scale null_ok + """.split() + ) + + def __repr__(self) -> str: + return f"" + + def __len__(self) -> int: + return 7 + + def _type_display(self) -> str: + parts = [] + t = builtins.get(self.type_code) + parts.append(t.name if t else str(self.type_code)) + + mod1 = self.precision + if mod1 is None: + mod1 = self.display_size + if mod1: + parts.append(f"({mod1}") + if self.scale: + parts.append(f", {self.scale}") + parts.append(")") + + return "".join(parts) + + def __getitem__(self, index: Any) -> Any: + if isinstance(index, slice): + return tuple(getter(self) for getter in self._attrs[index]) + else: + return self._attrs[index](self) + + @property + def name(self) -> str: + """The name of the column.""" + return self._name + + @property + def type_code(self) -> int: + """The numeric OID of the column.""" + return self._data.ftype + + @property + def display_size(self) -> Optional[int]: + """The field size, for :sql:`varchar(n)`, None otherwise.""" + t = builtins.get(self.type_code) + if not t: + return None + + if t.name in ("varchar", "char"): + fmod = self._data.fmod + if fmod >= 0: + return fmod - 4 + + return None + + @property + def internal_size(self) -> Optional[int]: + """The interal field size for fixed-size types, None otherwise.""" + fsize = self._data.fsize + return fsize if fsize >= 0 else None + + @property + def precision(self) -> Optional[int]: + """The number of digits for fixed precision types.""" + t = builtins.get(self.type_code) + if not t: + return None + + dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval") + if t.name == "numeric": + fmod = self._data.fmod + if fmod >= 0: + return fmod >> 16 + + elif t.name in dttypes: + fmod = self._data.fmod + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def scale(self) -> Optional[int]: + """The number of digits after the decimal point if available. + + TODO: probably better than precision for datetime objects? review. + """ + if self.type_code == builtins["numeric"].oid: + fmod = self._data.fmod - 4 + if fmod >= 0: + return fmod & 0xFFFF + + return None + + @property + def null_ok(self) -> Optional[bool]: + """Always `!None`""" + return None diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 8813dcb34..f370f0be1 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -10,7 +10,7 @@ import logging import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, Type, TYPE_CHECKING, Union +from typing import Optional, Type, TYPE_CHECKING from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager @@ -39,7 +39,7 @@ connect: Callable[[str], PQGen["PGconn"]] execute: Callable[["PGconn"], PQGen[List["PGresult"]]] if TYPE_CHECKING: - from .cursor import AsyncCursor, Cursor + from .cursor import AsyncCursor, BaseCursor, Cursor from .pq.proto import PGconn, PGresult if pq.__impl__ == "c": @@ -98,7 +98,7 @@ class BaseConnection: ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - cursor_factory: Union[Type["Cursor"], Type["AsyncCursor"]] + cursor_factory: Type["BaseCursor[Any]"] def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 7c424d5c8..1f4b854b2 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -8,15 +8,14 @@ import sys from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List from typing import Optional, Sequence, Type, TYPE_CHECKING -from operator import attrgetter from contextlib import contextmanager from . import errors as e from . import pq from .pq import ConnStatus, ExecStatus, Format -from .oids import builtins from .copy import Copy, AsyncCopy from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen +from ._column import Column from ._queries import PostgresQuery if sys.version_info >= (3, 7): @@ -42,125 +41,6 @@ else: execute = generators.execute -class Column(Sequence[Any]): - - __module__ = "psycopg3" - - def __init__(self, pgresult: "PGresult", index: int, encoding: str): - self._pgresult = pgresult - self._index = index - self._encoding = encoding - - _attrs = tuple( - attrgetter(attr) - for attr in """ - name type_code display_size internal_size precision scale null_ok - """.split() - ) - - def __repr__(self) -> str: - return f"" - - def __len__(self) -> int: - return 7 - - def _type_display(self) -> str: - parts = [] - t = builtins.get(self.type_code) - parts.append(t.name if t else str(self.type_code)) - - mod1 = self.precision - if mod1 is None: - mod1 = self.display_size - if mod1: - parts.append(f"({mod1}") - if self.scale: - parts.append(f", {self.scale}") - parts.append(")") - - return "".join(parts) - - def __getitem__(self, index: Any) -> Any: - if isinstance(index, slice): - return tuple(getter(self) for getter in self._attrs[index]) - else: - return self._attrs[index](self) - - @property - def name(self) -> str: - """The name of the column.""" - rv = self._pgresult.fname(self._index) - if rv: - return rv.decode(self._encoding) - else: - raise e.InterfaceError( - f"no name available for column {self._index}" - ) - - @property - def type_code(self) -> int: - """The numeric OID of the column.""" - return self._pgresult.ftype(self._index) - - @property - def display_size(self) -> Optional[int]: - """The field size, for :sql:`varchar(n)`, None otherwise.""" - t = builtins.get(self.type_code) - if not t: - return None - - if t.name in ("varchar", "char"): - fmod = self._pgresult.fmod(self._index) - if fmod >= 0: - return fmod - 4 - - return None - - @property - def internal_size(self) -> Optional[int]: - """The interal field size for fixed-size types, None otherwise.""" - fsize = self._pgresult.fsize(self._index) - return fsize if fsize >= 0 else None - - @property - def precision(self) -> Optional[int]: - """The number of digits for fixed precision types.""" - t = builtins.get(self.type_code) - if not t: - return None - - dttypes = ("time", "timetz", "timestamp", "timestamptz", "interval") - if t.name == "numeric": - fmod = self._pgresult.fmod(self._index) - if fmod >= 0: - return fmod >> 16 - - elif t.name in dttypes: - fmod = self._pgresult.fmod(self._index) - if fmod >= 0: - return fmod & 0xFFFF - - return None - - @property - def scale(self) -> Optional[int]: - """The number of digits after the decimal point if available. - - TODO: probably better than precision for datetime objects? review. - """ - if self.type_code == builtins["numeric"].oid: - fmod = self._pgresult.fmod(self._index) - 4 - if fmod >= 0: - return fmod & 0xFFFF - - return None - - @property - def null_ok(self) -> Optional[bool]: - """Always `!None`""" - return None - - class BaseCursor(Generic[ConnectionType]): ExecStatus = pq.ExecStatus @@ -236,8 +116,7 @@ class BaseCursor(Generic[ConnectionType]): res = self.pgresult if not res or res.status != ExecStatus.TUPLES_OK: return None - encoding = self._conn.client_encoding - return [Column(res, i, encoding) for i in range(res.nfields)] + return [Column(self, i) for i in range(res.nfields)] @property def rowcount(self) -> int: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 36ca40051..45750143e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,7 +1,9 @@ import gc -import pytest +import pickle import weakref +import pytest + import psycopg3 from psycopg3.oids import builtins @@ -343,3 +345,17 @@ class TestColumn: assert col.scale == scale assert col.display_size == dsize assert col.internal_size == isize + + def test_pickle(self, conn): + curs = conn.cursor() + curs.execute( + """select + 3.14::decimal(10,2) as pi, + 'hello'::text as hi, + '2010-02-18'::date as now + """ + ) + description = curs.description + pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) + unpickled = pickle.loads(pickled) + assert [tuple(d) for d in description] == [tuple(d) for d in unpickled]