import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING, Union
+from typing import Optional, Type, TYPE_CHECKING, TypeVar
from weakref import ref, ReferenceType
from functools import partial
from contextlib import contextmanager
from . import pq
from . import cursor
from . import errors as e
+from . import waiting
from . import encodings
from .pq import TransactionStatus, ExecStatus, Format
from .sql import Composable
from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query, Params
-from .waiting import wait, wait_async
from .conninfo import make_conninfo
from .generators import notifies
from .transaction import Transaction, AsyncTransaction
NoticeHandler = Callable[[e.Diagnostic], None]
NotifyHandler = Callable[[Notify], None]
+C = TypeVar("C", bound="BaseConnection")
+
class BaseConnection:
"""
def _set_client_encoding(self, name: str) -> None:
raise NotImplementedError
+ def _set_client_encoding_gen(self, name: str) -> PQGen[None]:
+ self.pgconn.send_query_params(
+ b"select set_config('client_encoding', $1, false)",
+ [encodings.py2pg(name)],
+ )
+ (result,) = yield from execute(self.pgconn)
+ if result.status != ExecStatus.TUPLES_OK:
+ raise e.error_from_result(result, encoding=self.client_encoding)
+
def cancel(self) -> None:
"""Cancel the current operation on the connection."""
c = self.pgconn.get_cancel()
for cb in self._notify_handlers:
cb(n)
+ # Generators to perform high-level operations on the connection
+ #
+ # These operations are expressed in terms of non-blocking generators
+ # and the task of waiting when needed (when the generators yield) is left
+ # to the connections subclass, which might wait either in blocking mode
+ # or through asyncio.
+ #
+ # All these generators assume exclusive acces to the connection: subclasses
+ # should have a lock and hold it before calling and consuming them.
+
+ @classmethod
+ def _connect_gen(
+ cls: Type[C],
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ **kwargs: Any,
+ ) -> PQGenConn[C]:
+ """Generator to connect to the database and create a new instance."""
+ conninfo = make_conninfo(conninfo, **kwargs)
+ pgconn = yield from connect(conninfo)
+ conn = cls(pgconn)
+ conn._autocommit = autocommit
+ return conn
+
+ def _exec_command(self, command: Query) -> PQGen[None]:
+ """
+ Generator to send a command and receive the result to the backend.
+
+ Only used to implement internal commands such as commit, returning
+ no result. The cursor can do more complex stuff.
+ """
+ if self.pgconn.status != self.ConnStatus.OK:
+ if self.pgconn.status == self.ConnStatus.BAD:
+ raise e.OperationalError("the connection is closed")
+ raise e.InterfaceError(
+ f"cannot execute operations: the connection is"
+ f" in status {self.pgconn.status}"
+ )
+
+ if isinstance(command, str):
+ command = command.encode(self.client_encoding)
+ elif isinstance(command, Composable):
+ command = command.as_bytes(self)
+
+ self.pgconn.send_query(command)
+ result = (yield from execute(self.pgconn))[-1]
+ if result.status != ExecStatus.COMMAND_OK:
+ if result.status == ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(
+ result, encoding=self.client_encoding
+ )
+ else:
+ raise e.InterfaceError(
+ f"unexpected result {ExecStatus(result.status).name}"
+ f" from command {command.decode('utf8')!r}"
+ )
+
+ def _start_query(self) -> PQGen[None]:
+ """Generator to start a transaction if necessary."""
+ if self._autocommit:
+ return
+
+ if self.pgconn.transaction_status != TransactionStatus.IDLE:
+ return
+
+ yield from self._exec_command(b"begin")
+
+ def _commit_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.commit()`."""
+ if self._savepoints:
+ raise e.ProgrammingError(
+ "Explicit commit() forbidden within a Transaction "
+ "context. (Transaction will be automatically committed "
+ "on successful exit from context.)"
+ )
+ if self.pgconn.transaction_status == TransactionStatus.IDLE:
+ return
+
+ yield from self._exec_command(b"commit")
+
+ def _rollback_gen(self) -> PQGen[None]:
+ """Generator implementing `Connection.rollback()`."""
+ if self._savepoints:
+ raise e.ProgrammingError(
+ "Explicit rollback() forbidden within a Transaction "
+ "context. (Either raise Rollback() or allow "
+ "an exception to propagate out of the context.)"
+ )
+ if self.pgconn.transaction_status == TransactionStatus.IDLE:
+ return
+
+ yield from self._exec_command(b"rollback")
+
class Connection(BaseConnection):
"""
TODO: connection_timeout to be implemented.
"""
-
- conninfo = make_conninfo(conninfo, **kwargs)
- gen = connect(conninfo)
- pgconn = cls.wait(gen)
- conn = cls(pgconn)
- conn._autocommit = autocommit
- return conn
+ return cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+ )
def __enter__(self) -> "Connection":
return self
return self.cursor_factory(self, format=format)
- def _start_query(self) -> None:
- # the function is meant to be called by a cursor once the lock is taken
- if self._autocommit:
- return
-
- if self.pgconn.transaction_status != TransactionStatus.IDLE:
- return
-
- self._exec_command(b"begin")
-
def execute(
self, query: Query, params: Optional[Params] = None
) -> "Cursor":
def commit(self) -> None:
"""Commit any pending transaction to the database."""
with self.lock:
- if self._savepoints:
- raise e.ProgrammingError(
- "Explicit commit() forbidden within a Transaction "
- "context. (Transaction will be automatically committed "
- "on successful exit from context.)"
- )
- if self.pgconn.transaction_status == TransactionStatus.IDLE:
- return
- self._exec_command(b"commit")
+ self.wait(self._commit_gen())
def rollback(self) -> None:
"""Roll back to the start of any pending transaction."""
with self.lock:
- if self._savepoints:
- raise e.ProgrammingError(
- "Explicit rollback() forbidden within a Transaction "
- "context. (Either raise Rollback() or allow "
- "an exception to propagate out of the context.)"
- )
- if self.pgconn.transaction_status == TransactionStatus.IDLE:
- return
- self._exec_command(b"rollback")
-
- def _exec_command(self, command: Query) -> None:
- # Caller must hold self.lock
-
- if isinstance(command, str):
- command = command.encode(self.client_encoding)
- elif isinstance(command, Composable):
- command = command.as_string(self).encode(self.client_encoding)
-
- self.pgconn.send_query(command)
- result = self.wait(execute(self.pgconn))[-1]
- if result.status != ExecStatus.COMMAND_OK:
- if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(
- result, encoding=self.client_encoding
- )
- else:
- raise e.InterfaceError(
- f"unexpected result {ExecStatus(result.status).name}"
- f" from command {command.decode('utf8')!r}"
- )
+ self.wait(self._rollback_gen())
@contextmanager
def transaction(
with Transaction(self, savepoint_name, force_rollback) as tx:
yield tx
- @classmethod
- def wait(
- cls,
- gen: Union[PQGen[RV], PQGenConn[RV]],
- timeout: Optional[float] = 0.1,
- ) -> RV:
- return wait(gen, timeout=timeout)
-
- def _set_client_encoding(self, name: str) -> None:
- with self.lock:
- self.pgconn.send_query_params(
- b"select set_config('client_encoding', $1, false)",
- [encodings.py2pg(name)],
- )
- gen = execute(self.pgconn)
- (result,) = self.wait(gen)
- if result.status != ExecStatus.TUPLES_OK:
- raise e.error_from_result(
- result, encoding=self.client_encoding
- )
-
def notifies(self) -> Iterator[Notify]:
"""
Yield `Notify` objects as soon as they are received from the database.
)
yield n
+ def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ """
+ Consume a generator operating on the connection.
+
+ The function must be used on generators that don't change connection
+ fd (i.e. not on connect and reset).
+ """
+ return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+
+ @classmethod
+ def _wait_conn(
+ cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1
+ ) -> RV:
+ """Consume a connection generator."""
+ return waiting.wait_conn(gen, timeout=timeout)
+
def _set_autocommit(self, value: bool) -> None:
with self.lock:
super()._set_autocommit(value)
+ def _set_client_encoding(self, name: str) -> None:
+ with self.lock:
+ self.wait(self._set_client_encoding_gen(name))
+
class AsyncConnection(BaseConnection):
"""
async def connect(
cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any
) -> "AsyncConnection":
- conninfo = make_conninfo(conninfo, **kwargs)
- gen = connect(conninfo)
- pgconn = await cls.wait(gen)
- conn = cls(pgconn)
- conn._autocommit = autocommit
- return conn
+ return await cls._wait_conn(
+ cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+ )
async def __aenter__(self) -> "AsyncConnection":
return self
return self.cursor_factory(self, format=format)
- async def _start_query(self) -> None:
- # the function is meant to be called by a cursor once the lock is taken
- if self._autocommit:
- return
-
- if self.pgconn.transaction_status != TransactionStatus.IDLE:
- return
-
- await self._exec_command(b"begin")
-
async def execute(
self, query: Query, params: Optional[Params] = None
) -> "AsyncCursor":
async def commit(self) -> None:
async with self.lock:
- if self._savepoints:
- raise e.ProgrammingError(
- "Explicit commit() forbidden within a Transaction "
- "context. (Transaction will be automatically committed "
- "on successful exit from context.)"
- )
- if self.pgconn.transaction_status == TransactionStatus.IDLE:
- return
- await self._exec_command(b"commit")
+ await self.wait(self._commit_gen())
async def rollback(self) -> None:
async with self.lock:
- if self._savepoints:
- raise e.ProgrammingError(
- "Explicit rollback() forbidden within a Transaction "
- "context. (Either raise Rollback() or allow "
- "an exception to propagate out of the context.)"
- )
- if self.pgconn.transaction_status == TransactionStatus.IDLE:
- return
- await self._exec_command(b"rollback")
-
- async def _exec_command(self, command: Query) -> None:
- # Caller must hold self.lock
-
- if isinstance(command, str):
- command = command.encode(self.client_encoding)
- elif isinstance(command, Composable):
- command = command.as_string(self).encode(self.client_encoding)
-
- self.pgconn.send_query(command)
- result = (await self.wait(execute(self.pgconn)))[-1]
- if result.status != ExecStatus.COMMAND_OK:
- if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(
- result, encoding=self.client_encoding
- )
- else:
- raise e.InterfaceError(
- f"unexpected result {ExecStatus(result.status).name}"
- f" from command {command.decode('utf8')!r}"
- )
+ await self.wait(self._rollback_gen())
@asynccontextmanager
async def transaction(
async with tx:
yield tx
+ async def notifies(self) -> AsyncIterator[Notify]:
+ while 1:
+ async with self.lock:
+ ns = await self.wait(notifies(self.pgconn))
+ enc = self.client_encoding
+ for pgn in ns:
+ n = Notify(
+ pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
+ )
+ yield n
+
+ async def wait(self, gen: PQGen[RV]) -> RV:
+ return await waiting.wait_async(gen, self.pgconn.socket)
+
@classmethod
- async def wait(cls, gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
- return await wait_async(gen)
+ async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV:
+ return await waiting.wait_async_conn(gen)
def _set_client_encoding(self, name: str) -> None:
raise AttributeError(
async def set_client_encoding(self, name: str) -> None:
"""Async version of the `~Connection.client_encoding` setter."""
async with self.lock:
- self.pgconn.send_query_params(
- b"select set_config('client_encoding', $1, false)",
- [encodings.py2pg(name)],
- )
- gen = execute(self.pgconn)
- (result,) = await self.wait(gen)
- if result.status != ExecStatus.TUPLES_OK:
- raise e.error_from_result(
- result, encoding=self.client_encoding
- )
-
- async def notifies(self) -> AsyncIterator[Notify]:
- while 1:
- async with self.lock:
- ns = await self.wait(notifies(self.pgconn))
- enc = self.client_encoding
- for pgn in ns:
- n = Notify(
- pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
- )
- yield n
+ await self.wait(self._set_client_encoding_gen(name))
def _set_autocommit(self, value: bool) -> None:
raise AttributeError(
from . import errors as e
from . import pq
-from .pq import ConnStatus, ExecStatus, Format
+from .pq import ExecStatus, Format
from .copy import Copy, AsyncCopy
from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
from ._column import Column
else:
return None
- def _start_query(self) -> None:
+ #
+ # Generators for the high level operations on the cursor
+ #
+ # Like for sync/async connections, these are implemented as generators
+ # so that different concurrency strategies (threads,asyncio) can use their
+ # own way of waiting (or better, `connection.wait()`).
+ #
+
+ def _execute_gen(
+ self, query: Query, params: Optional[Params] = None
+ ) -> PQGen[None]:
+ """Generator implementing `Cursor.execute()`."""
+ yield from self._start_query()
+ self._execute_send(query, params)
+ results = yield from execute(self._conn.pgconn)
+ self._execute_results(results)
+
+ def _executemany_gen(
+ self, query: Query, params_seq: Sequence[Params]
+ ) -> PQGen[None]:
+ """Generator implementing `Cursor.executemany()`."""
+ yield from self._start_query()
+ first = True
+ for params in params_seq:
+ if first:
+ pgq = self._send_prepare(b"", query, params)
+ (result,) = yield from execute(self._conn.pgconn)
+ if result.status == ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(
+ result, encoding=self._conn.client_encoding
+ )
+ else:
+ pgq.dump(params)
+
+ self._send_query_prepared(b"", pgq)
+ (result,) = yield from execute(self._conn.pgconn)
+ self._execute_results((result,))
+
+ def _start_query(self) -> PQGen[None]:
+ """Generator to start the processing of a query.
+
+ It is implemented as generator because it may send additional queries,
+ such as `begin`.
+ """
from . import adapt
if self.closed:
raise e.InterfaceError("the cursor is closed")
- if self._conn.closed:
- raise e.InterfaceError("the connection is closed")
-
- if self._conn.pgconn.status != ConnStatus.OK:
- raise e.InterfaceError(
- f"cannot execute operations: the connection is"
- f" in status {self._conn.pgconn.status}"
- )
-
self._reset()
self._transformer = adapt.Transformer(self)
+ yield from self._conn._start_query()
+
+ def _start_copy_gen(self, statement: Query) -> PQGen[None]:
+ """Generator implementing sending a command for `Cursor.copy()."""
+ yield from self._start_query()
+ # Make sure to avoid PQexec to avoid receiving a mix of COPY and
+ # other operations.
+ self._execute_send(statement, None, no_pqexec=True)
+ (result,) = yield from execute(self._conn.pgconn)
+ self._check_copy_result(result)
+ self.pgresult = result # will set it on the transformer too
def _execute_send(
self, query: Query, params: Optional[Params], no_pqexec: bool = False
) -> None:
"""
- Implement part of execute() before waiting common to sync and async
+ Implement part of execute() before waiting common to sync and async.
+
+ This is not a generator, but a normal, non-blocking function.
"""
pgq = PostgresQuery(self._transformer)
pgq.convert(query, params)
def _execute_results(self, results: Sequence["PGresult"]) -> None:
"""
Implement part of execute() after waiting common to sync and async
+
+ This is not a generator, but a normal, non-blocking function.
"""
if not results:
raise e.InternalError("got no result from the query")
def _send_prepare(
self, name: bytes, query: Query, params: Optional[Params]
) -> PostgresQuery:
- """
- Implement part of execute() before waiting common to sync and async
- """
pgq = PostgresQuery(self._transformer)
pgq.convert(query, params)
"the last operation didn't produce a result"
)
- def _check_copy_results(self, results: Sequence["PGresult"]) -> None:
+ def _check_copy_result(self, result: "PGresult") -> None:
"""
Check that the value returned in a copy() operation is a legit COPY.
"""
- if len(results) != 1:
- raise e.InternalError(
- f"expected 1 result from copy, got {len(results)} instead"
- )
-
- result = results[0]
status = result.status
if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT):
return
Execute a query or command to the database.
"""
with self._conn.lock:
- self._start_query()
- self._conn._start_query()
- self._execute_send(query, params)
- gen = execute(self._conn.pgconn)
- results = self._conn.wait(gen)
- self._execute_results(results)
+ self._conn.wait(self._execute_gen(query, params))
return self
def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
Execute the same command with a sequence of input data.
"""
with self._conn.lock:
- self._start_query()
- self._conn._start_query()
- first = True
- for params in params_seq:
- if first:
- pgq = self._send_prepare(b"", query, params)
- gen = execute(self._conn.pgconn)
- (result,) = self._conn.wait(gen)
- if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(
- result, encoding=self._conn.client_encoding
- )
- else:
- pgq.dump(params)
-
- self._send_query_prepared(b"", pgq)
- gen = execute(self._conn.pgconn)
- (result,) = self._conn.wait(gen)
- self._execute_results((result,))
+ self._conn.wait(self._executemany_gen(query, params_seq))
def fetchone(self) -> Optional[Sequence[Any]]:
"""
"""
Initiate a :sql:`COPY` operation and return an object to manage it.
"""
- with self._start_copy(statement) as copy:
- yield copy
-
- def _start_copy(self, statement: Query) -> Copy:
with self._conn.lock:
- self._start_query()
- self._conn._start_query()
- # Make sure to avoid PQexec to avoid receiving a mix of COPY and
- # other operations.
- self._execute_send(statement, None, no_pqexec=True)
- gen = execute(self._conn.pgconn)
- results = self._conn.wait(gen)
- self._check_copy_results(results)
- self.pgresult = results[0] # will set it on the transformer too
+ self._conn.wait(self._start_copy_gen(statement))
- return Copy(self)
+ with Copy(self) as copy:
+ yield copy
class AsyncCursor(BaseCursor["AsyncConnection"]):
self, query: Query, params: Optional[Params] = None
) -> "AsyncCursor":
async with self._conn.lock:
- self._start_query()
- await self._conn._start_query()
- self._execute_send(query, params)
- gen = execute(self._conn.pgconn)
- results = await self._conn.wait(gen)
- self._execute_results(results)
+ await self._conn.wait(self._execute_gen(query, params))
return self
async def executemany(
self, query: Query, params_seq: Sequence[Params]
) -> None:
async with self._conn.lock:
- self._start_query()
- await self._conn._start_query()
- first = True
- for params in params_seq:
- if first:
- pgq = self._send_prepare(b"", query, params)
- gen = execute(self._conn.pgconn)
- (result,) = await self._conn.wait(gen)
- if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(
- result, encoding=self._conn.client_encoding
- )
- else:
- pgq.dump(params)
-
- self._send_query_prepared(b"", pgq)
- gen = execute(self._conn.pgconn)
- (result,) = await self._conn.wait(gen)
- self._execute_results((result,))
+ await self._conn.wait(self._executemany_gen(query, params_seq))
async def fetchone(self) -> Optional[Sequence[Any]]:
self._check_result()
@asynccontextmanager
async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
- copy = await self._start_copy(statement)
- async with copy:
- yield copy
-
- async def _start_copy(self, statement: Query) -> AsyncCopy:
async with self._conn.lock:
- self._start_query()
- await self._conn._start_query()
- # Make sure to avoid PQexec to avoid receiving a mix of COPY and
- # other operations.
- self._execute_send(statement, None, no_pqexec=True)
- gen = execute(self._conn.pgconn)
- results = await self._conn.wait(gen)
- self._check_copy_results(results)
- self.pgresult = results[0] # will set it on the transformer too
-
- return AsyncCopy(self)
+ await self._conn.wait(self._start_copy_gen(statement))
+
+ async with AsyncCopy(self) as copy:
+ yield copy
class NamedCursorMixin:
After this generator has finished you may want to cycle using `_fetch()`
to retrieve the results available.
"""
- yield pgconn.socket
while 1:
f = pgconn.flush()
if f == 0:
def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
- yield pgconn.socket
yield Wait.R
pgconn.consume_input()
def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
- yield pgconn.socket
while 1:
nbytes, data = pgconn.get_copy_data(1)
if nbytes != 0:
def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
- yield pgconn.socket
# Retry enqueuing data until successful
while pgconn.put_copy_data(buffer) == 0:
yield Wait.W
def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
- yield pgconn.socket
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
yield Wait.W
This can happen in connection and reset, but not in normal querying.
"""
-PQGen = Generator[Union[int, "Wait"], "Ready", RV]
+PQGen = Generator["Wait", "Ready", RV]
"""Generator for processes where the connection file number won't change.
The first item generated is the file descriptor; following items are be the
import logging
from types import TracebackType
-from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, Optional, Type, Union, TYPE_CHECKING
from . import sql
from .pq import TransactionStatus
-from .proto import ConnectionType
+from .proto import ConnectionType, PQGen
if TYPE_CHECKING:
from .connection import Connection, AsyncConnection # noqa: F401
args.append("force_rollback=True")
return f"{self.__class__.__qualname__}({', '.join(args)})"
- def _enter_commands(self) -> List[bytes]:
+ def _enter_gen(self) -> PQGen[None]:
if not self._yolo:
raise TypeError("transaction blocks can be used only once")
else:
)
self._conn._savepoints.append(self._savepoint_name)
- return commands
+ return self._conn._exec_command(b"; ".join(commands))
- def _commit_commands(self) -> List[bytes]:
+ def _exit_gen(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> PQGen[bool]:
+ if not exc_val and not self.force_rollback:
+ yield from self._commit_gen()
+ return False
+ else:
+ return (yield from self._rollback_gen(exc_val))
+
+ def _commit_gen(self) -> PQGen[None]:
assert self._conn._savepoints[-1] == self._savepoint_name
self._conn._savepoints.pop()
assert not self._conn._savepoints
commands.append(b"commit")
- return commands
+ return self._conn._exec_command(b"; ".join(commands))
+
+ def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
+ if isinstance(exc_val, Rollback):
+ _log.debug(
+ f"{self._conn}: Explicit rollback from: ", exc_info=True
+ )
- def _rollback_commands(self) -> List[bytes]:
assert self._conn._savepoints[-1] == self._savepoint_name
self._conn._savepoints.pop()
assert not self._conn._savepoints
commands.append(b"rollback")
- return commands
+ yield from self._conn._exec_command(b"; ".join(commands))
+
+ if isinstance(exc_val, Rollback):
+ if not exc_val.transaction or exc_val.transaction is self:
+ return True # Swallow the exception
+
+ return False
class Transaction(BaseTransaction["Connection"]):
def __enter__(self) -> "Transaction":
with self._conn.lock:
- self._execute(self._enter_commands())
+ self._conn.wait(self._enter_gen())
return self
def __exit__(
exc_tb: Optional[TracebackType],
) -> bool:
with self._conn.lock:
- if not exc_val and not self.force_rollback:
- self._commit()
- return False
- else:
- return self._rollback(exc_val)
-
- def _commit(self) -> None:
- """Commit changes made in the transaction context."""
- self._execute(self._commit_commands())
-
- def _rollback(self, exc_val: Optional[BaseException]) -> bool:
- # Rollback changes made in the transaction context
- if isinstance(exc_val, Rollback):
- _log.debug(
- f"{self._conn}: Explicit rollback from: ", exc_info=True
- )
-
- self._execute(self._rollback_commands())
-
- if isinstance(exc_val, Rollback):
- if exc_val.transaction in (self, None):
- return True # Swallow the exception
-
- return False
-
- def _execute(self, commands: List[bytes]) -> None:
- self._conn._exec_command(b"; ".join(commands))
+ return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
class AsyncTransaction(BaseTransaction["AsyncConnection"]):
async def __aenter__(self) -> "AsyncTransaction":
async with self._conn.lock:
- await self._execute(self._enter_commands())
-
+ await self._conn.wait(self._enter_gen())
return self
async def __aexit__(
exc_tb: Optional[TracebackType],
) -> bool:
async with self._conn.lock:
- if not exc_val and not self.force_rollback:
- await self._commit()
- return False
- else:
- return await self._rollback(exc_val)
-
- async def _commit(self) -> None:
- """Commit changes made in the transaction context."""
- await self._execute(self._commit_commands())
-
- async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
- # Rollback changes made in the transaction context
- if isinstance(exc_val, Rollback):
- _log.debug(
- f"{self._conn}: Explicit rollback from: ", exc_info=True
+ return await self._conn.wait(
+ self._exit_gen(exc_type, exc_val, exc_tb)
)
-
- await self._execute(self._rollback_commands())
-
- if isinstance(exc_val, Rollback):
- if exc_val.transaction in (self, None):
- return True # Swallow the exception
-
- return False
-
- async def _execute(self, commands: List[bytes]) -> None:
- await self._conn._exec_command(b"; ".join(commands))
from enum import IntEnum
-from typing import Optional, Union
+from typing import Optional
from asyncio import get_event_loop, Event
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
W = EVENT_WRITE
-def wait(
- gen: Union[PQGen[RV], PQGenConn[RV]], timeout: Optional[float] = None
-) -> RV:
+def wait(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
- Wait for a generator using the best option available on the platform.
+ Wait for a generator using the best strategy available.
+
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C.
+ :type timeout: float
+ :return: whatever *gen* returns on completion.
+
+ Consume *gen*, scheduling `fileno` for completion when it is reported to
+ block. Once ready again send the ready state back to *gen*.
+ """
+ sel = DefaultSelector()
+ try:
+ s = next(gen)
+ while 1:
+ sel.register(fileno, s)
+ ready = None
+ while not ready:
+ ready = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ s = gen.send(ready[0][1])
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+ """
+ Wait for a connection generator using the best strategy available.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
to allow Ctrl-C.
:type timeout: float
:return: whatever *gen* returns on completion.
+
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
"""
- fd: int
- s: Wait
sel = DefaultSelector()
try:
- # Use the first generated item to tell if it's a PQgen or PQgenConn.
- # Note: mypy gets confused by the behaviour of this generator.
- item = next(gen)
- if isinstance(item, tuple):
- fd, s = item
- while 1:
- sel.register(fd, s)
- ready = None
- while not ready:
- ready = sel.select(timeout=timeout)
- sel.unregister(fd)
-
- assert len(ready) == 1
- fd, s = gen.send(ready[0][1])
- else:
- fd = item # type: ignore[assignment]
- s = next(gen) # type: ignore[assignment]
- while 1:
- sel.register(fd, s)
- ready = None
- while not ready:
- ready = sel.select(timeout=timeout)
- sel.unregister(fd)
-
- assert len(ready) == 1
- s = gen.send(ready[0][1]) # type: ignore[arg-type,assignment]
+ fileno, s = next(gen)
+ while 1:
+ sel.register(fileno, s)
+ ready = None
+ while not ready:
+ ready = sel.select(timeout=timeout)
+ sel.unregister(fileno)
+ fileno, s = gen.send(ready[0][1])
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
-async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
+async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
"""
Coroutine waiting for a generator to complete.
- *gen* is expected to generate tuples (fd, status). consume it and block
- according to the status until fd is ready. Send back the ready state
- to the generator.
+ :param gen: a generator performing database operations and yielding
+ `Ready` values when it would block.
+ :param fileno: the file descriptor to wait on.
+ :return: whatever *gen* returns on completion.
+
+ Behave like in `wait()`, but exposing an `asyncio` interface.
+ """
+ # Use an event to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ ev = Event()
+ loop = get_event_loop()
+ ready: Ready
+ s: Wait
+
+ def wakeup(state: Ready) -> None:
+ nonlocal ready
+ ready = state
+ ev.set()
+
+ try:
+ s = next(gen)
+ while 1:
+ ev.clear()
+ if s == Wait.R:
+ loop.add_reader(fileno, wakeup, Ready.R)
+ await ev.wait()
+ loop.remove_reader(fileno)
+ elif s == Wait.W:
+ loop.add_writer(fileno, wakeup, Ready.W)
+ await ev.wait()
+ loop.remove_writer(fileno)
+ elif s == Wait.RW:
+ loop.add_reader(fileno, wakeup, Ready.R)
+ loop.add_writer(fileno, wakeup, Ready.W)
+ await ev.wait()
+ loop.remove_reader(fileno)
+ loop.remove_writer(fileno)
+ else:
+ raise e.InternalError("bad poll status: %s")
+ s = gen.send(ready)
+
+ except StopIteration as ex:
+ rv: RV = ex.args[0] if ex.args else None
+ return rv
+
+
+async def wait_async_conn(gen: PQGenConn[RV]) -> RV:
+ """
+ Coroutine waiting for a connection generator to complete.
+
+ :param gen: a generator performing database operations and yielding
+ (fd, `Ready`) pairs when it would block.
+ :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+ to allow Ctrl-C.
+ :return: whatever *gen* returns on completion.
- Return what the generator eventually returned.
+ Behave like in `wait()`, but take the fileno to wait from the generator
+ itself, which might change during processing.
"""
# Use an event to block and restart after the fd state changes.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
ready: Ready
- fd: int
s: Wait
def wakeup(state: Ready) -> None:
ev.set()
try:
- # Use the first generated item to tell if it's a PQgen or PQgenConn.
- # Note: mypy gets confused by the behaviour of this generator.
- item = next(gen)
- if isinstance(item, tuple):
- fd, s = item
- while 1:
- ev.clear()
- if s == Wait.R:
- loop.add_reader(fd, wakeup, Ready.R)
- await ev.wait()
- loop.remove_reader(fd)
- elif s == Wait.W:
- loop.add_writer(fd, wakeup, Ready.W)
- await ev.wait()
- loop.remove_writer(fd)
- elif s == Wait.RW:
- loop.add_reader(fd, wakeup, Ready.R)
- loop.add_writer(fd, wakeup, Ready.W)
- await ev.wait()
- loop.remove_reader(fd)
- loop.remove_writer(fd)
- else:
- raise e.InternalError("bad poll status: %s")
- fd, s = gen.send(ready) # type: ignore[misc]
- else:
- fd = item # type: ignore[assignment]
- s = next(gen) # type: ignore[assignment]
- while 1:
- ev.clear()
- if s == Wait.R:
- loop.add_reader(fd, wakeup, Ready.R)
- await ev.wait()
- loop.remove_reader(fd)
- elif s == Wait.W:
- loop.add_writer(fd, wakeup, Ready.W)
- await ev.wait()
- loop.remove_writer(fd)
- elif s == Wait.RW:
- loop.add_reader(fd, wakeup, Ready.R)
- loop.add_writer(fd, wakeup, Ready.W)
- await ev.wait()
- loop.remove_reader(fd)
- loop.remove_writer(fd)
- else:
- raise e.InternalError("bad poll status: %s")
- s = gen.send(ready) # type: ignore[arg-type,assignment]
+ fileno, s = next(gen)
+ while 1:
+ ev.clear()
+ if s == Wait.R:
+ loop.add_reader(fileno, wakeup, Ready.R)
+ await ev.wait()
+ loop.remove_reader(fileno)
+ elif s == Wait.W:
+ loop.add_writer(fileno, wakeup, Ready.W)
+ await ev.wait()
+ loop.remove_writer(fileno)
+ elif s == Wait.RW:
+ loop.add_reader(fileno, wakeup, Ready.R)
+ loop.add_writer(fileno, wakeup, Ready.W)
+ await ev.wait()
+ loop.remove_reader(fileno)
+ loop.remove_writer(fileno)
+ else:
+ raise e.InternalError("bad poll status: %s")
+ fileno, s = gen.send(ready)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
cdef libpq.PGresult *pgres
cdef int cires, ibres
- # Start the generator by sending the connection fd, which won't change
- # during the query process.
- yield libpq.PQsocket(pgconn_ptr)
-
# Sending the query
while 1:
if libpq.PQflush(pgconn_ptr) == 0:
from psycopg3.generators import execute
+def execute_wait(pgconn):
+ return psycopg3.waiting.wait(execute(pgconn), pgconn.socket)
+
+
def test_send_query(pgconn):
# This test shows how to process an async query in all its glory
pgconn.nonblocking = 1
b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
% (b"x" * 1_000_000)
)
- results = psycopg3.waiting.wait(execute(pgconn))
+ results = execute_wait(pgconn)
assert len(results) == 2
assert results[0].nfields == 1
def test_send_query_params(pgconn):
pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == b"8"
def test_send_prepare(pgconn):
pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
pgconn.send_query_prepared(b"prep", [b"3", b"5"])
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.get_value(0, 0) == b"8"
pgconn.finish()
def test_send_prepare_types(pgconn):
pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23])
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
pgconn.send_query_prepared(b"prep", [b"3", b"5"])
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.get_value(0, 0) == b"8"
def test_send_prepared_binary_in(pgconn):
val = b"foo\00bar"
pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)")
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1])
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == b"3"
assert res.get_value(0, 1) == b"7"
def test_send_prepared_binary_out(pgconn, fmt, out):
val = b"foo\00bar"
pgconn.send_prepare(b"", b"select $1::bytea")
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
pgconn.send_query_prepared(
b"", [val], param_formats=[1], result_format=fmt
)
- (res,) = psycopg3.waiting.wait(execute(pgconn))
+ (res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == out
assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
pgconn.send_query(b"select 1")
assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
- psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+ psycopg3.waiting.wait(psycopg3.generators.execute(pgconn), pgconn.socket)
assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
pgconn.finish()
assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN
assert conn.closed
assert conn.pgconn.status == conn.ConnStatus.BAD
- with pytest.raises(psycopg3.InterfaceError):
+ with pytest.raises(psycopg3.OperationalError):
cur.execute("select 1")
assert aconn.closed
assert aconn.pgconn.status == aconn.ConnStatus.BAD
- with pytest.raises(psycopg3.InterfaceError):
+ with pytest.raises(psycopg3.OperationalError):
await cur.execute("select 1")
@pytest.fixture
def commands(conn, monkeypatch):
- """The queue of commands issued internally by the test connection."""
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(conn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+ """Helper to implement the commands fixture both sync and async."""
_orig_exec_command = conn._exec_command
L = ListPopAll()
command = command.decode(conn.client_encoding)
L.insert(0, command)
- _orig_exec_command(command)
+ return _orig_exec_command(command)
monkeypatch.setattr(conn, "_exec_command", _exec_command)
- yield L
+ return L
def in_transaction(conn):
from psycopg3 import ProgrammingError, Rollback
from .test_transaction import in_transaction, insert_row, inserted
-from .test_transaction import ExpectedException, ListPopAll
+from .test_transaction import ExpectedException, patch_exec
from .test_transaction import create_test_table # noqa # autouse fixture
pytestmark = pytest.mark.asyncio
@pytest.fixture
-async def commands(aconn, monkeypatch):
- """The queue of commands issued internally by the test connection."""
- _orig_exec_command = aconn._exec_command
- L = ListPopAll()
-
- async def _exec_command(command):
- if isinstance(command, bytes):
- command = command.decode(aconn.client_encoding)
-
- L.insert(0, command)
- await _orig_exec_command(command)
-
- monkeypatch.setattr(aconn, "_exec_command", _exec_command)
- yield L
+def commands(aconn, monkeypatch):
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(aconn, monkeypatch)
async def test_basic(aconn):