# Copyright (C) 2020 The Psycopg Team
-from typing import Any, Callable, Optional, Type, Union
+from typing import Any, cast, Callable, Optional, Type, Union
from . import pq
from . import proto
elif isinstance(context, BaseConnection):
return context
elif isinstance(context, BaseCursor):
- return context.connection
+ return cast(BaseConnection, context.connection)
elif isinstance(context, Transformer):
return context.connection
else:
import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, cast
+from typing import Optional, Type, TYPE_CHECKING, Union
from weakref import ref, ReferenceType
from functools import partial
from . import pq
-from . import proto
from . import cursor
from . import errors as e
from . import encodings
from .pq import TransactionStatus, ExecStatus
+from .proto import DumpersMap, LoadersMap, PQGen, RV
from .waiting import wait, wait_async
from .conninfo import make_conninfo
from .generators import notifies
logger = logging.getLogger(__name__)
package_logger = logging.getLogger("psycopg3")
-connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
-execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]]
+connect: Callable[[str], PQGen["PGconn"]]
+execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
+
+if TYPE_CHECKING:
+ from .pq.proto import PGconn, PGresult
+ from .cursor import Cursor, AsyncCursor
if pq.__impl__ == "c":
from psycopg3_c import _psycopg3
ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
- def __init__(self, pgconn: pq.proto.PGconn):
+ cursor_factory: Union[Type["Cursor"], Type["AsyncCursor"]]
+
+ def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn # TODO: document this
- self.cursor_factory = cursor.BaseCursor
self._autocommit = False
- self.dumpers: proto.DumpersMap = {}
- self.loaders: proto.LoadersMap = {}
+ self.dumpers: DumpersMap = {}
+ self.loaders: LoadersMap = {}
self._notice_handlers: List[NoticeHandler] = []
self._notify_handlers: List[NotifyHandler] = []
)
self._autocommit = value
- def _cursor(
- self, name: str = "", format: pq.Format = pq.Format.TEXT
- ) -> cursor.BaseCursor:
- if name:
- raise NotImplementedError
- return self.cursor_factory(self, format=format)
-
@property
def client_encoding(self) -> str:
"""The Python codec name of the connection's client encoding."""
@staticmethod
def _notice_handler(
- wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
+ wself: "ReferenceType[BaseConnection]", res: "PGresult"
) -> None:
self = wself()
if not (self and self._notice_handler):
cursor_factory: Type[cursor.Cursor]
- def __init__(self, pgconn: pq.proto.PGconn):
+ def __init__(self, pgconn: "PGconn"):
super().__init__(pgconn)
self.lock = threading.Lock()
self.cursor_factory = cursor.Cursor
"""
Return a new `Cursor` to send commands and queries to the connection.
"""
- cur = self._cursor(name, format=format)
- return cast(cursor.Cursor, cur)
+ if name:
+ raise NotImplementedError
+
+ return self.cursor_factory(self, format=format)
def _start_query(self) -> None:
# the function is meant to be called by a cursor once the lock is taken
)
@classmethod
- def wait(
- cls, gen: proto.PQGen[proto.RV], timeout: Optional[float] = 0.1
- ) -> proto.RV:
+ def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
return wait(gen, timeout=timeout)
def _set_client_encoding(self, name: str) -> None:
cursor_factory: Type[cursor.AsyncCursor]
- def __init__(self, pgconn: pq.proto.PGconn):
+ def __init__(self, pgconn: "PGconn"):
super().__init__(pgconn)
self.lock = asyncio.Lock()
self.cursor_factory = cursor.AsyncCursor
"""
Return a new `AsyncCursor` to send commands and queries to the connection.
"""
- cur = self._cursor(name, format=format)
- return cast(cursor.AsyncCursor, cur)
+ if name:
+ raise NotImplementedError
+
+ return self.cursor_factory(self, format=format)
async def _start_query(self) -> None:
# the function is meant to be called by a cursor once the lock is taken
)
@classmethod
- async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV:
+ async def wait(cls, gen: PQGen[RV]) -> RV:
return await wait_async(gen)
def _set_client_encoding(self, name: str) -> None:
import re
import struct
-from typing import TYPE_CHECKING, AsyncIterator, Iterator
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
from types import TracebackType
-from . import pq
-from .proto import AdaptContext
+from .pq import Format
+from .proto import ConnectionType, Transformer
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
- from .connection import BaseConnection, Connection, AsyncConnection
+ from .pq.proto import PGresult
+ from .connection import Connection, AsyncConnection # noqa: F401
-class BaseCopy:
+class BaseCopy(Generic[ConnectionType]):
def __init__(
self,
- context: AdaptContext,
- result: Optional[pq.proto.PGresult],
- format: pq.Format = pq.Format.TEXT,
+ connection: ConnectionType,
+ transformer: Transformer,
+ result: "PGresult",
):
- from .adapt import Transformer
-
- self._connection: Optional["BaseConnection"] = None
- self._transformer = Transformer(context)
- self.format = format
+ self.connection = connection
+ self._transformer = transformer
self.pgresult = result
+ self.format = result.binary_tuples
self._first_row = True
self._finished = False
self._encoding: str = ""
- if format == pq.Format.TEXT:
+ if self.format == Format.TEXT:
self._format_row = self._format_row_text
else:
self._format_row = self._format_row_binary
return self._finished
@property
- def connection(self) -> "BaseConnection":
- if self._connection:
- return self._connection
-
- self._connection = conn = self._transformer.connection
- if conn:
- return conn
-
- raise ValueError("no connection available")
-
- @property
- def pgresult(self) -> Optional[pq.proto.PGresult]:
+ def pgresult(self) -> Optional["PGresult"]:
return self._pgresult
@pgresult.setter
- def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+ def pgresult(self, result: Optional["PGresult"]) -> None:
self._pgresult = result
self._transformer.pgresult = result
if (
self.pgresult is None
- or self.pgresult.binary_tuples == pq.Format.BINARY
+ or self.pgresult.binary_tuples == Format.BINARY
):
raise TypeError(
"cannot copy str data in binary mode: use bytes instead"
_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
-class Copy(BaseCopy):
- _connection: Optional["Connection"]
-
- @property
- def connection(self) -> "Connection":
- # TODO: mypy error: "Callable[[BaseCopy], BaseConnection]" has no
- # attribute "fget"
- return BaseCopy.connection.fget(self) # type: ignore
-
+class Copy(BaseCopy["Connection"]):
def read(self) -> Optional[bytes]:
if self._finished:
return None
exc_tb: Optional[TracebackType],
) -> None:
if exc_val is None:
- if self.format == pq.Format.BINARY and not self._first_row:
+ if self.format == Format.BINARY and not self._first_row:
# send EOF only if we copied binary rows (_first_row is False)
self.write(b"\xff\xff")
self.finish()
yield data
-class AsyncCopy(BaseCopy):
- _connection: Optional["AsyncConnection"]
-
- @property
- def connection(self) -> "AsyncConnection":
- return BaseCopy.connection.fget(self) # type: ignore
-
+class AsyncCopy(BaseCopy["AsyncConnection"]):
async def read(self) -> Optional[bytes]:
if self._finished:
return None
exc_tb: Optional[TracebackType],
) -> None:
if exc_val is None:
- if self.format == pq.Format.BINARY and not self._first_row:
+ if self.format == Format.BINARY and not self._first_row:
# send EOF only if we copied binary rows (_first_row is False)
await self.write(b"\xff\xff")
await self.finish()
# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Generic,
+ Iterator,
+ List,
+ Mapping,
+)
from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
from operator import attrgetter
from . import errors as e
from . import pq
from . import sql
-from . import proto
from .oids import builtins
from .copy import Copy, AsyncCopy
-from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
+from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
from .utils.queries import PostgresQuery
if TYPE_CHECKING:
- from .connection import BaseConnection, Connection, AsyncConnection
+ from .proto import Transformer
+ from .pq.proto import PGconn, PGresult
+ from .connection import Connection, AsyncConnection # noqa: F401
-execute: Callable[[pq.proto.PGconn], PQGen[List[pq.proto.PGresult]]]
+execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
if pq.__impl__ == "c":
from psycopg3_c import _psycopg3
class Column(Sequence[Any]):
- def __init__(self, pgresult: pq.proto.PGresult, index: int, encoding: str):
+ def __init__(self, pgresult: "PGresult", index: int, encoding: str):
self._pgresult = pgresult
self._index = index
self._encoding = encoding
return None
-class BaseCursor:
+class BaseCursor(Generic[ConnectionType]):
ExecStatus = pq.ExecStatus
- _transformer: proto.Transformer
+ _transformer: "Transformer"
def __init__(
- self, connection: "BaseConnection", format: pq.Format = pq.Format.TEXT
+ self,
+ connection: ConnectionType,
+ format: pq.Format = pq.Format.TEXT,
):
self.connection = connection
self.format = format
self._closed = False
def _reset(self) -> None:
- self._results: List[pq.proto.PGresult] = []
+ self._results: List["PGresult"] = []
self.pgresult = None
self._pos = 0
self._iresult = 0
return res.status if res else None
@property
- def pgresult(self) -> Optional[pq.proto.PGresult]:
+ def pgresult(self) -> Optional["PGresult"]:
"""The `~psycopg3.pq.PGresult` exposed by the cursor."""
return self._pgresult
@pgresult.setter
- def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+ def pgresult(self, result: Optional["PGresult"]) -> None:
self._pgresult = result
if result and self._transformer:
self._transformer.pgresult = result
return None
def _start_query(self) -> None:
- from .adapt import Transformer
+ from . import adapt
if self.closed:
raise e.InterfaceError("the cursor is closed")
)
self._reset()
- self._transformer = Transformer(self)
+ self._transformer = adapt.Transformer(self)
def _execute_send(
self, query: Query, vars: Optional[Params], no_pqexec: bool = False
# one query in one go
self.connection.pgconn.send_query(pgq.query)
- def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None:
+ def _execute_results(self, results: Sequence["PGresult"]) -> None:
"""
Implement part of execute() after waiting common to sync and async
"""
qparts.append(sql.SQL(")"))
return sql.Composed(qparts)
- def _check_copy_results(
- self, results: Sequence[pq.proto.PGresult]
- ) -> None:
+ def _check_copy_results(self, results: Sequence["PGresult"]) -> None:
"""
Check that the value returned in a copy() operation is a legit COPY.
"""
)
-class Cursor(BaseCursor):
- connection: "Connection"
-
- def __init__(
- self, connection: "Connection", format: pq.Format = pq.Format.TEXT
- ):
- super().__init__(connection, format=format)
-
+class Cursor(BaseCursor["Connection"]):
def __enter__(self) -> "Cursor":
return self
self._execute_send(statement, vars, no_pqexec=True)
gen = execute(self.connection.pgconn)
results = self.connection.wait(gen)
- tx = self._transformer
self._check_copy_results(results)
return Copy(
- context=tx, result=results[0], format=results[0].binary_tuples
+ connection=self.connection,
+ transformer=self._transformer,
+ result=results[0],
)
-class AsyncCursor(BaseCursor):
- connection: "AsyncConnection"
-
- def __init__(
- self, connection: "AsyncConnection", format: pq.Format = pq.Format.TEXT
- ):
- super().__init__(connection, format=format)
-
+class AsyncCursor(BaseCursor["AsyncConnection"]):
async def __aenter__(self) -> "AsyncCursor":
return self
self._execute_send(statement, vars, no_pqexec=True)
gen = execute(self.connection.pgconn)
results = await self.connection.wait(gen)
- tx = self._transformer
self._check_copy_results(results)
return AsyncCopy(
- context=tx, result=results[0], format=results[0].binary_tuples
+ connection=self.connection,
+ transformer=self._transformer,
+ result=results[0],
)
# Copyright (C) 2020 The Psycopg Team
-from typing import Any, Callable, Optional, Sequence, NewType
+from typing import Any, Callable, Optional, Sequence
from ctypes import Array, pointer
from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
Query = Union[str, bytes, "Composable"]
Params = Union[Sequence[Any], Mapping[str, Any]]
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
# Waiting protocol types