Both sync and async versions are supported.
Fixes: #6842
Change-Id: I57751c5028acebfc6f9c43572562405453a2f2a4
This section covers new features and improvements in SQLAlchemy 2.0 which
are not otherwise part of the major 1.4->2.0 migration path.
+.. _ticket_6842:
+
+Dialect support for psycopg 3 (a.k.a. "psycopg")
+-------------------------------------------------
+
+Added dialect support for the `psycopg 3 <https://pypi.org/project/psycopg/>`_
+DBAPI, which despite the number "3" now goes by the package name ``psycopg``,
+superseding the previous ``psycopg2`` package that for the time being remains
+SQLAlchemy's "default" driver for the ``postgresql`` dialects. ``psycopg`` is a
+completely reworked and modernized database adapter for PostgreSQL which
+supports concepts such as prepared statements as well as Python asyncio.
+
+``psycopg`` is the first DBAPI supported by SQLAlchemy that provides
+both a PEP 249 synchronous API and an asyncio driver. The same
+``psycopg`` database URL may be used with the :func:`_sa.create_engine`
+and :func:`_asyncio.create_async_engine` engine-creation functions, and the
+corresponding sync or asyncio version of the dialect will be selected
+automatically.
+
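+For example, reusing the placeholder credentials shown in the dialect
+documentation, the same URL works with either engine-creation function::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.ext.asyncio import create_async_engine
+
+    # a sync Engine
+    sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
+
+    # an AsyncEngine, selected automatically from the same URL format
+    asyncio_engine = create_async_engine(
+        "postgresql+psycopg://scott:tiger@localhost/test"
+    )
+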
+.. seealso::
+
+ :ref:`postgresql_psycopg`
+
+
+
Behavioral Changes
==================
--- /dev/null
+.. change::
+ :tags: postgresql, dialect
+ :tickets: 6842
+
+    Added support for the ``psycopg`` dialect, which supports both sync and
+    async execution. The dialect is available under the
+    ``postgresql+psycopg`` name for both the :func:`_sa.create_engine` and
+    :func:`_asyncio.create_async_engine` engine-creation functions.
+
+ .. seealso::
+
+ :ref:`ticket_6842`
+
+ :ref:`postgresql_psycopg`
\ No newline at end of file
.. automodule:: sqlalchemy.dialects.postgresql.psycopg2
+.. _postgresql_psycopg:
+
+psycopg
+--------
+
+.. automodule:: sqlalchemy.dialects.postgresql.psycopg
+
pg8000
------
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
await_ = staticmethod(await_only)
- __slots__ = ("dbapi", "_connection", "_execute_mutex")
+ __slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from types import ModuleType
+
from . import asyncpg # noqa
from . import base
from . import pg8000 # noqa
+from . import psycopg # noqa
from . import psycopg2 # noqa
from . import psycopg2cffi # noqa
from .array import All
from .ranges import TSTZRANGE
from ...util import compat
+# Alias psycopg also as psycopg_async
+psycopg_async = type(
+    "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
+)
+
base.dialect = dialect = psycopg2.dialect
--- /dev/null
+import decimal
+
+from .array import ARRAY as PGARRAY
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import UUID
+from .hstore import HSTORE
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+_server_side_id = util.counter()
+
+
+class _PsycopgNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # psycopg returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # psycopg returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PsycopgHStore(HSTORE):
+ def bind_processor(self, dialect):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PsycopgHStore, self).bind_processor(dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PsycopgHStore, self).result_processor(
+ dialect, coltype
+ )
+
+
+class _PsycopgUUID(UUID):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+class _PsycopgARRAY(PGARRAY):
+ render_bind_cast = True
+
+
+class _PGExecutionContext_common_psycopg(PGExecutionContext):
+ def create_server_side_cursor(self):
+ # use server-side cursors:
+ # psycopg
+ # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors
+ # psycopg2
+ # https://www.psycopg.org/docs/usage.html#server-side-cursors
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+ return self._dbapi_connection.cursor(ident)
+
+
+class _PGDialect_common_psycopg(PGDialect):
+ supports_statement_cache = True
+ supports_server_side_cursors = True
+
+ default_paramstyle = "pyformat"
+
+ _has_native_hstore = True
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: _PsycopgNumeric,
+ HSTORE: _PsycopgHStore,
+ UUID: _PsycopgUUID,
+ sqltypes.ARRAY: _PsycopgARRAY,
+ },
+ )
+
+ def __init__(
+ self,
+ client_encoding=None,
+ use_native_hstore=True,
+ use_native_uuid=True,
+ **kwargs
+ ):
+ PGDialect.__init__(self, **kwargs)
+ if not use_native_hstore:
+ self._has_native_hstore = False
+ self.use_native_hstore = use_native_hstore
+ self.use_native_uuid = use_native_uuid
+ self.client_encoding = client_encoding
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user", database="dbname")
+
+ is_multihost = False
+ if "host" in url.query:
+ is_multihost = isinstance(url.query["host"], (list, tuple))
+
+ if opts:
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ opts.update(url.query)
+ if is_multihost:
+ opts["host"] = ",".join(url.query["host"])
+                # send individual dbname, user, password, host, port
+                # parameters to psycopg.connect() / psycopg2.connect()
+ return ([], opts)
+ elif url.query:
+ # any other connection arguments, pass directly
+ opts.update(url.query)
+ if is_multihost:
+ opts["host"] = ",".join(url.query["host"])
+ return ([], opts)
+ else:
+            # no connection arguments whatsoever; the psycopg-family
+            # connect() requires that "dsn" be present as a blank string.
+ return ([""], opts)
+
+ def get_isolation_level_values(self, dbapi_conn):
+ return (
+ "AUTOCOMMIT",
+ "READ COMMITTED",
+ "READ UNCOMMITTED",
+ "REPEATABLE READ",
+ "SERIALIZABLE",
+ )
+
+ def set_deferrable(self, connection, value):
+ connection.deferrable = value
+
+ def get_deferrable(self, connection):
+ return connection.deferrable
+
+ def _do_autocommit(self, connection, value):
+ connection.autocommit = value
+
+ def do_ping(self, dbapi_connection):
+ cursor = None
+ try:
+ self._do_autocommit(dbapi_connection, True)
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(self._dialect_specific_select_one)
+ finally:
+ cursor.close()
+ if not dbapi_connection.closed:
+ self._do_autocommit(dbapi_connection, False)
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, cursor):
+ return False
+ else:
+ raise
+ else:
+ return True
class AsyncAdapt_asyncpg_connection(AdaptedConnection):
__slots__ = (
"dbapi",
- "_connection",
"isolation_level",
"_isolation_setting",
"readonly",
or as Python uuid objects.
The UUID type is currently known to work within the prominent DBAPI
- drivers supported by SQLAlchemy including psycopg2, pg8000 and
+ drivers supported by SQLAlchemy including psycopg, psycopg2, pg8000 and
asyncpg. Support for other DBAPI drivers may be incomplete or non-present.
"""
self.connection.execute(DropEnumType(enum))
+ def get_dbapi_type(self, dbapi):
+        """don't return dbapi.STRING for ENUM in PostgreSQL, since that's
+ a different type"""
+
+ return None
+
def _check_for_name_in_memos(self, checkfirst, kw):
"""Look in the 'ddl runner' for 'memos', then
note our name in that collection.
--- /dev/null
+# postgresql/psycopg.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg
+ :name: psycopg (a.k.a. psycopg 3)
+ :dbapi: psycopg
+ :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/psycopg/
+
+``psycopg`` is the package and module name for version 3 of the ``psycopg``
+database driver, formerly known as ``psycopg2``. This driver is different
+enough from its ``psycopg2`` predecessor that SQLAlchemy supports it
+via a totally separate dialect; support for ``psycopg2`` is expected to remain
+for as long as that package continues to function for modern Python versions,
+and ``psycopg2`` also remains the default driver for the ``postgresql://``
+dialect series.
+
+The SQLAlchemy ``psycopg`` dialect provides both a sync and an async
+implementation under the same dialect name. The proper version is
+selected depending on how the engine is created:
+
+* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will
+ automatically select the sync version, e.g.::
+
+ from sqlalchemy import create_engine
+ sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
+
+* calling :func:`_asyncio.create_async_engine` with
+ ``postgresql+psycopg://...`` will automatically select the async version,
+ e.g.::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")
+
+The asyncio version of the dialect may also be specified explicitly using the
+``psycopg_async`` suffix, as::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")
+
+The ``psycopg`` dialect has the same API features as that of ``psycopg2``,
+with the exception of the "fast executemany" helpers. The "fast executemany"
+helpers are expected to be generalized and ported to ``psycopg`` before the
+final release of SQLAlchemy 2.0.
+
+
+.. seealso::
+
+ :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg``
+ dialect shares most of its behavior with the ``psycopg2`` dialect.
+ Further documentation is available there.
+
+""" # noqa
+import logging
+import re
+
+from ._psycopg_common import _PGDialect_common_psycopg
+from ._psycopg_common import _PGExecutionContext_common_psycopg
+from ._psycopg_common import _PsycopgUUID
+from .base import INTERVAL
+from .base import PGCompiler
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .json import JSON
+from .json import JSONB
+from .json import JSONPathType
+from ... import pool
+from ... import types as sqltypes
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+logger = logging.getLogger("sqlalchemy.dialects.postgresql")
+
+
+class _PGString(sqltypes.String):
+ render_bind_cast = True
+
+
+class _PGJSON(JSON):
+ render_bind_cast = True
+
+ def bind_processor(self, dialect):
+ return self._make_bind_processor(None, dialect._psycopg_Json)
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGJSONB(JSONB):
+ render_bind_cast = True
+
+ def bind_processor(self, dialect):
+ return self._make_bind_processor(None, dialect._psycopg_Jsonb)
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+ __visit_name__ = "json_int_index"
+
+ render_bind_cast = True
+
+
+class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+ __visit_name__ = "json_str_index"
+
+ render_bind_cast = True
+
+
+class _PGJSONPathType(JSONPathType):
+ pass
+
+
+class _PGUUID(_PsycopgUUID):
+ render_bind_cast = True
+
+
+class _PGInterval(INTERVAL):
+ render_bind_cast = True
+
+
+class _PGTimeStamp(sqltypes.DateTime):
+ render_bind_cast = True
+
+
+class _PGDate(sqltypes.Date):
+ render_bind_cast = True
+
+
+class _PGTime(sqltypes.Time):
+ render_bind_cast = True
+
+
+class _PGInteger(sqltypes.Integer):
+ render_bind_cast = True
+
+
+class _PGSmallInteger(sqltypes.SmallInteger):
+ render_bind_cast = True
+
+
+class _PGNullType(sqltypes.NullType):
+ render_bind_cast = True
+
+
+class _PGBigInteger(sqltypes.BigInteger):
+ render_bind_cast = True
+
+
+class _PGBoolean(sqltypes.Boolean):
+ render_bind_cast = True
+
+
+class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
+ pass
+
+
+class PGCompiler_psycopg(PGCompiler):
+ pass
+
+
+class PGIdentifierPreparer_psycopg(PGIdentifierPreparer):
+ pass
+
+
+def _log_notices(diagnostic):
+ logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary)
+
+
+class PGDialect_psycopg(_PGDialect_common_psycopg):
+ driver = "psycopg"
+
+ supports_statement_cache = True
+ supports_server_side_cursors = True
+ default_paramstyle = "pyformat"
+ supports_sane_multi_rowcount = True
+
+ execution_ctx_cls = PGExecutionContext_psycopg
+ statement_compiler = PGCompiler_psycopg
+ preparer = PGIdentifierPreparer_psycopg
+ psycopg_version = (0, 0)
+
+ _has_native_hstore = True
+
+ colspecs = util.update_copy(
+ _PGDialect_common_psycopg.colspecs,
+ {
+ sqltypes.String: _PGString,
+ JSON: _PGJSON,
+ sqltypes.JSON: _PGJSON,
+ JSONB: _PGJSONB,
+ sqltypes.JSON.JSONPathType: _PGJSONPathType,
+ sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
+ sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
+ UUID: _PGUUID,
+ sqltypes.Interval: _PGInterval,
+ INTERVAL: _PGInterval,
+ sqltypes.Date: _PGDate,
+ sqltypes.DateTime: _PGTimeStamp,
+ sqltypes.Time: _PGTime,
+ sqltypes.Integer: _PGInteger,
+ sqltypes.SmallInteger: _PGSmallInteger,
+ sqltypes.BigInteger: _PGBigInteger,
+ },
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ if self.dbapi:
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+ if m:
+ self.psycopg_version = tuple(
+ int(x) for x in m.group(1, 2, 3) if x is not None
+ )
+
+ if self.psycopg_version < (3, 0, 2):
+ raise ImportError(
+ "psycopg version 3.0.2 or higher is required."
+ )
+
+ from psycopg.adapt import AdaptersMap
+
+ self._psycopg_adapters_map = adapters_map = AdaptersMap(
+ self.dbapi.adapters
+ )
+
+ if self._json_deserializer:
+ from psycopg.types.json import set_json_loads
+
+ set_json_loads(self._json_deserializer, adapters_map)
+
+ if self._json_serializer:
+ from psycopg.types.json import set_json_dumps
+
+ set_json_dumps(self._json_serializer, adapters_map)
+
+ def create_connect_args(self, url):
+ # see https://github.com/psycopg/psycopg/issues/83
+ cargs, cparams = super().create_connect_args(url)
+
+ cparams["context"] = self._psycopg_adapters_map
+ if self.client_encoding is not None:
+ cparams["client_encoding"] = self.client_encoding
+ return cargs, cparams
+
+ def _type_info_fetch(self, connection, name):
+ from psycopg.types import TypeInfo
+
+ return TypeInfo.fetch(connection.connection, name)
+
+ def initialize(self, connection):
+ super().initialize(connection)
+
+ # PGDialect.initialize() checks server version for <= 8.2 and sets
+ # this flag to False if so
+ if not self.full_returning:
+ self.insert_executemany_returning = False
+
+ # HSTORE can't be registered until we have a connection so that
+ # we can look up its OID, so we set up this adapter in
+ # initialize()
+ if self.use_native_hstore:
+ info = self._type_info_fetch(connection, "hstore")
+ self._has_native_hstore = info is not None
+ if self._has_native_hstore:
+ from psycopg.types.hstore import register_hstore
+
+ # register the adapter for connections made subsequent to
+ # this one
+ register_hstore(info, self._psycopg_adapters_map)
+
+ # register the adapter for this connection
+ register_hstore(info, connection.connection)
+
+ @classmethod
+ def dbapi(cls):
+ import psycopg
+
+ return psycopg
+
+ @classmethod
+ def get_async_dialect_cls(cls, url):
+ return PGDialectAsync_psycopg
+
+ @util.memoized_property
+ def _isolation_lookup(self):
+ return {
+ "READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED,
+ "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED,
+ "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ,
+ "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE,
+ }
+
+ @util.memoized_property
+ def _psycopg_Json(self):
+ from psycopg.types import json
+
+ return json.Json
+
+ @util.memoized_property
+ def _psycopg_Jsonb(self):
+ from psycopg.types import json
+
+ return json.Jsonb
+
+ @util.memoized_property
+ def _psycopg_TransactionStatus(self):
+ from psycopg.pq import TransactionStatus
+
+ return TransactionStatus
+
+ def _do_isolation_level(self, connection, autocommit, isolation_level):
+ connection.autocommit = autocommit
+ connection.isolation_level = isolation_level
+
+ def get_isolation_level(self, dbapi_connection):
+ if hasattr(dbapi_connection, "dbapi_connection"):
+ dbapi_connection = dbapi_connection.dbapi_connection
+
+ status_before = dbapi_connection.info.transaction_status
+ value = super().get_isolation_level(dbapi_connection)
+
+ # don't rely on psycopg providing enum symbols, compare with
+ # eq/ne
+ if status_before == self._psycopg_TransactionStatus.IDLE:
+ dbapi_connection.rollback()
+ return value
+
+ def set_isolation_level(self, connection, level):
+ connection = getattr(connection, "dbapi_connection", connection)
+ if level == "AUTOCOMMIT":
+ self._do_isolation_level(
+ connection, autocommit=True, isolation_level=None
+ )
+ else:
+ self._do_isolation_level(
+ connection,
+ autocommit=False,
+ isolation_level=self._isolation_lookup[level],
+ )
+
+ def set_readonly(self, connection, value):
+ connection.read_only = value
+
+ def get_readonly(self, connection):
+ return connection.read_only
+
+ def on_connect(self):
+ def notices(conn):
+ conn.add_notice_handler(_log_notices)
+
+ fns = [notices]
+
+ if self.isolation_level is not None:
+
+ def on_connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(on_connect)
+
+ # fns always has the notices function
+ def on_connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return on_connect
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error) and connection is not None:
+ if connection.closed or connection.broken:
+ return True
+ return False
+
+ def _do_prepared_twophase(self, connection, command, recover=False):
+ dbapi_conn = connection.connection.dbapi_connection
+ if (
+ recover
+ # don't rely on psycopg providing enum symbols, compare with
+ # eq/ne
+ or dbapi_conn.info.transaction_status
+ != self._psycopg_TransactionStatus.IDLE
+ ):
+ dbapi_conn.rollback()
+ before = dbapi_conn.autocommit
+ try:
+ self._do_autocommit(dbapi_conn, True)
+ dbapi_conn.execute(command)
+ finally:
+ self._do_autocommit(dbapi_conn, before)
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ self._do_prepared_twophase(
+ connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
+ )
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ self._do_prepared_twophase(
+ connection, f"COMMIT PREPARED '{xid}'", recover=recover
+ )
+ else:
+ self.do_commit(connection.connection)
+
+
+class AsyncAdapt_psycopg_cursor:
+ __slots__ = ("_cursor", "await_", "_rows")
+
+ _psycopg_ExecStatus = None
+
+ def __init__(self, cursor, await_) -> None:
+ self._cursor = cursor
+ self.await_ = await_
+ self._rows = []
+
+ def __getattr__(self, name):
+ return getattr(self._cursor, name)
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ def close(self):
+ self._rows.clear()
+        # a normal cursor can just call the synchronous _close() method.
+ self._cursor._close()
+
+ def execute(self, query, params=None, **kw):
+ result = self.await_(self._cursor.execute(query, params, **kw))
+        # SQLAlchemy's result is not async, so we need to pull all rows here
+ res = self._cursor.pgresult
+
+ # don't rely on psycopg providing enum symbols, compare with
+ # eq/ne
+ if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
+ rows = self.await_(self._cursor.fetchall())
+ if not isinstance(rows, list):
+ self._rows = list(rows)
+ else:
+ self._rows = rows
+ return result
+
+ def executemany(self, query, params_seq):
+ return self.await_(self._cursor.executemany(query, params_seq))
+
+ def __iter__(self):
+ # TODO: try to avoid pop(0) on a list
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ # TODO: try to avoid pop(0) on a list
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self._cursor.arraysize
+
+ retval = self._rows[0:size]
+ self._rows = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows
+ self._rows = []
+ return retval
+
+
+class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
+ def execute(self, query, params=None, **kw):
+ self.await_(self._cursor.execute(query, params, **kw))
+ return self
+
+ def close(self):
+ self.await_(self._cursor.close())
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=0):
+ return self.await_(self._cursor.fetchmany(size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+ def __iter__(self):
+ iterator = self._cursor.__aiter__()
+ while True:
+ try:
+ yield self.await_(iterator.__anext__())
+ except StopAsyncIteration:
+ break
+
+
+class AsyncAdapt_psycopg_connection(AdaptedConnection):
+ __slots__ = ()
+ await_ = staticmethod(await_only)
+
+ def __init__(self, connection) -> None:
+ self._connection = connection
+
+ def __getattr__(self, name):
+ return getattr(self._connection, name)
+
+ def execute(self, query, params=None, **kw):
+ cursor = self.await_(self._connection.execute(query, params, **kw))
+ return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+
+ def cursor(self, *args, **kw):
+ cursor = self._connection.cursor(*args, **kw)
+ if hasattr(cursor, "name"):
+ return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_)
+ else:
+ return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def close(self):
+ self.await_(self._connection.close())
+
+ @property
+ def autocommit(self):
+ return self._connection.autocommit
+
+ @autocommit.setter
+ def autocommit(self, value):
+ self.set_autocommit(value)
+
+ def set_autocommit(self, value):
+ self.await_(self._connection.set_autocommit(value))
+
+ def set_isolation_level(self, value):
+ self.await_(self._connection.set_isolation_level(value))
+
+ def set_read_only(self, value):
+ self.await_(self._connection.set_read_only(value))
+
+ def set_deferrable(self, value):
+ self.await_(self._connection.set_deferrable(value))
+
+
+class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
+ __slots__ = ()
+ await_ = staticmethod(await_fallback)
+
+
+class PsycopgAdaptDBAPI:
+ def __init__(self, psycopg) -> None:
+ self.psycopg = psycopg
+
+ for k, v in self.psycopg.__dict__.items():
+ if k != "connect":
+ self.__dict__[k] = v
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_psycopg_connection(
+ await_fallback(
+ self.psycopg.AsyncConnection.connect(*arg, **kw)
+ )
+ )
+ else:
+ return AsyncAdapt_psycopg_connection(
+ await_only(self.psycopg.AsyncConnection.connect(*arg, **kw))
+ )
+
+
+class PGDialectAsync_psycopg(PGDialect_psycopg):
+ is_async = True
+ supports_statement_cache = True
+
+ @classmethod
+ def dbapi(cls):
+ import psycopg
+ from psycopg.pq import ExecStatus
+
+ AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus
+
+ return PsycopgAdaptDBAPI(psycopg)
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def _type_info_fetch(self, connection, name):
+ from psycopg.types import TypeInfo
+
+ adapted = connection.connection
+ return adapted.await_(TypeInfo.fetch(adapted._connection, name))
+
+ def _do_isolation_level(self, connection, autocommit, isolation_level):
+ connection.set_autocommit(autocommit)
+ connection.set_isolation_level(isolation_level)
+
+ def _do_autocommit(self, connection, value):
+ connection.set_autocommit(value)
+
+ def set_readonly(self, connection, value):
+ connection.set_read_only(value)
+
+ def set_deferrable(self, connection, value):
+ connection.set_deferrable(value)
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = PGDialect_psycopg
+dialect_async = PGDialectAsync_psycopg
:connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/psycopg2/
+.. _psycopg2_toplevel:
+
psycopg2 Connect Arguments
--------------------------
""" # noqa
import collections.abc as collections_abc
-import decimal
import logging
import re
-from uuid import UUID as _python_UUID
-from .array import ARRAY as PGARRAY
-from .base import _DECIMAL_TYPES
-from .base import _FLOAT_TYPES
-from .base import _INT_TYPES
+from ._psycopg_common import _PGDialect_common_psycopg
+from ._psycopg_common import _PGExecutionContext_common_psycopg
from .base import PGCompiler
-from .base import PGDialect
-from .base import PGExecutionContext
from .base import PGIdentifierPreparer
-from .base import UUID
-from .hstore import HSTORE
from .json import JSON
from .json import JSONB
-from ... import exc
-from ... import processors
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
-class _PGNumeric(sqltypes.Numeric):
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect, coltype):
- if self.asdecimal:
- if coltype in _FLOAT_TYPES:
- return processors.to_decimal_processor_factory(
- decimal.Decimal, self._effective_decimal_return_scale
- )
- elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
- # pg8000 returns Decimal natively for 1700
- return None
- else:
- raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype
- )
- else:
- if coltype in _FLOAT_TYPES:
- # pg8000 returns float natively for 701
- return None
- elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
- return processors.to_float
- else:
- raise exc.InvalidRequestError(
- "Unknown PG numeric type: %d" % coltype
- )
-
-
-class _PGHStore(HSTORE):
- def bind_processor(self, dialect):
- if dialect._has_native_hstore:
- return None
- else:
- return super(_PGHStore, self).bind_processor(dialect)
-
- def result_processor(self, dialect, coltype):
- if dialect._has_native_hstore:
- return None
- else:
- return super(_PGHStore, self).result_processor(dialect, coltype)
-
-
-class _PGARRAY(PGARRAY):
- render_bind_cast = True
-
-
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
return None
return None
-class _PGUUID(UUID):
- def bind_processor(self, dialect):
- if not self.as_uuid and dialect.use_native_uuid:
-
- def process(value):
- if value is not None:
- value = _python_UUID(value)
- return value
-
- return process
-
- def result_processor(self, dialect, coltype):
- if not self.as_uuid and dialect.use_native_uuid:
-
- def process(value):
- if value is not None:
- value = str(value)
- return value
-
- return process
-
-
-_server_side_id = util.counter()
-
-
-class PGExecutionContext_psycopg2(PGExecutionContext):
+class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
_psycopg2_fetched_rows = None
- def create_server_side_cursor(self):
- # use server-side cursors:
- # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
- return self._dbapi_connection.cursor(ident)
-
def post_exec(self):
if (
self._psycopg2_fetched_rows
)
-class PGDialect_psycopg2(PGDialect):
+class PGDialect_psycopg2(_PGDialect_common_psycopg):
driver = "psycopg2"
supports_statement_cache = True
_has_native_hstore = True
colspecs = util.update_copy(
- PGDialect.colspecs,
+ _PGDialect_common_psycopg.colspecs,
{
- sqltypes.Numeric: _PGNumeric,
- HSTORE: _PGHStore,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
- UUID: _PGUUID,
- sqltypes.ARRAY: _PGARRAY,
},
)
def __init__(
self,
- client_encoding=None,
- use_native_hstore=True,
- use_native_uuid=True,
executemany_mode="values_only",
executemany_batch_page_size=100,
executemany_values_page_size=1000,
**kwargs
):
- PGDialect.__init__(self, **kwargs)
- if not use_native_hstore:
- self._has_native_hstore = False
- self.use_native_hstore = use_native_hstore
- self.use_native_uuid = use_native_uuid
- self.client_encoding = client_encoding
+ _PGDialect_common_psycopg.__init__(self, **kwargs)
# Parse executemany_mode argument, allowing it to be only one of the
# symbol names
"SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
- def get_isolation_level_values(self, dbapi_conn):
- return list(self._isolation_lookup)
-
def set_isolation_level(self, connection, level):
connection.set_isolation_level(self._isolation_lookup[level])
def get_deferrable(self, connection):
return connection.deferrable
- def do_ping(self, dbapi_connection):
- cursor = None
- try:
- dbapi_connection.autocommit = True
- cursor = dbapi_connection.cursor()
- try:
- cursor.execute(self._dialect_specific_select_one)
- finally:
- cursor.close()
- if not dbapi_connection.closed:
- dbapi_connection.autocommit = False
- except self.dbapi.Error as err:
- if self.is_disconnect(err, dbapi_connection, cursor):
- return False
- else:
- raise
- else:
- return True
-
def on_connect(self):
extras = self._psycopg2_extras
else:
return None
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username="user")
-
- is_multihost = False
- if "host" in url.query:
- is_multihost = isinstance(url.query["host"], (list, tuple))
-
- if opts:
- if "port" in opts:
- opts["port"] = int(opts["port"])
- opts.update(url.query)
- if is_multihost:
- opts["host"] = ",".join(url.query["host"])
- # send individual dbname, user, password, host, port
- # parameters to psycopg2.connect()
- return ([], opts)
- elif url.query:
- # any other connection arguments, pass directly
- opts.update(url.query)
- if is_multihost:
- opts["host"] = ",".join(url.query["host"])
- return ([], opts)
- else:
- # no connection arguments whatsoever; psycopg2.connect()
- # requires that "dsn" be present as a blank string.
- return ([""], opts)
-
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
# check the "closed" flag. this might not be
u, plugins, kwargs = u._instantiate_plugins(kwargs)
entrypoint = u._get_entrypoint()
- dialect_cls = entrypoint.get_dialect_cls(u)
+ _is_async = kwargs.pop("_is_async", False)
+ if _is_async:
+ dialect_cls = entrypoint.get_async_dialect_cls(u)
+ else:
+ dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop("_coerce_config", False):
"""
+
if (not hard and self._soft_closed) or (hard and self.closed):
return
result.out_parameters = out_parameters
def _setup_dml_or_text_result(self):
+
if self.isinsert:
if self.compiled.postfetch_lastrowid:
self.inserted_primary_key_rows = (
# assert not result.returns_rows
elif self.isupdate and self._is_implicit_returning:
+            # access the rowcount while the cursor is still open, as some
+            # drivers require an open cursor for it.  we were not doing this
+            # in 1.4; however test_rowcount ->
+            # test_update_rowcount_return_defaults exercises it, and psycopg
+            # will no longer return a rowcount after the cursor is closed.
+            result.rowcount
+
row = result.fetchone()
self.returned_default_rows = [row]
+
result._soft_close()
# test that it has a cursor metadata that is accurate.
dialect = self.dialect
+ # all of the rest of this... cython?
+
if dialect._has_events:
inputsizes = dict(inputsizes)
dialect.dispatch.do_setinputsizes(
"""
return cls
+ @classmethod
+ def get_async_dialect_cls(cls, url):
+ """Given a URL, return the :class:`.Dialect` that will be used by
+ an async engine.
+
+ By default this is an alias of :meth:`.Dialect.get_dialect_cls` and
+ just returns the cls. It may be used if a dialect provides
+ both a sync and async version under the same name, like the
+ ``psycopg`` driver.
+
+        .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Dialect.get_dialect_cls`
+
+ """
+ return cls.get_dialect_cls(url)
+
@classmethod
def load_provisioning(cls):
"""set up the provision.py module for this dialect.
else:
return cls
- def get_dialect(self):
+ def get_dialect(self, _is_async=False):
"""Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
to this URL's driver name.
"""
entrypoint = self._get_entrypoint()
- dialect_cls = entrypoint.get_dialect_cls(self)
+ if _is_async:
+ dialect_cls = entrypoint.get_async_dialect_cls(self)
+ else:
+ dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
def translate_connect_args(self, names=None, **kw):
"streaming result set"
)
kw["future"] = True
+ kw["_is_async"] = True
sync_engine = _create_engine(*arg, **kw)
return AsyncEngine(sync_engine)
def _str_impl(self):
return String()
- def bind_processor(self, dialect):
- string_process = self._str_impl.bind_processor(dialect)
+ def _make_bind_processor(self, string_process, json_serializer):
+ if string_process:
- json_serializer = dialect._json_serializer or json.dumps
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, elements.Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
- def process(value):
- if value is self.NULL:
- value = None
- elif isinstance(value, elements.Null) or (
- value is None and self.none_as_null
- ):
- return None
+ serialized = json_serializer(value)
+ return string_process(serialized)
- serialized = json_serializer(value)
- if string_process:
- serialized = string_process(serialized)
- return serialized
+ else:
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, elements.Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+
+ return json_serializer(value)
return process
+ def bind_processor(self, dialect):
+ string_process = self._str_impl.bind_processor(dialect)
+ json_serializer = dialect._json_serializer or json.dumps
+
+ return self._make_bind_processor(string_process, json_serializer)
+
def result_processor(self, dialect, coltype):
string_process = self._str_impl.result_processor(dialect, coltype)
json_deserializer = dialect._json_deserializer or json.loads
_configs = set()
def _set_name(self, db):
+ suffix = "_async" if db.dialect.is_async else ""
if db.dialect.server_version_info:
svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
- self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
+ self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi)
else:
- self.name = "%s+%s" % (db.name, db.driver)
+ self.name = "%s+%s%s" % (db.name, db.driver, suffix)
@classmethod
def register(cls, db, db_opts, options, file_config):
)
if not all_configs:
- msg = "'%s' unsupported on any DB implementation %s%s" % (
+ msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
+ cls.__module__,
cls.__name__,
", ".join(
"'%s(%s)+%s'"
sequence. This should be false only for oracle.
"""
return exclusions.open()
+
+ @property
+ def generic_classes(self):
+ "If X[Y] can be implemented with ``__class_getitem__``. py3.7+"
+ return exclusions.open()
+
+ @property
+ def json_deserializer_binary(self):
+ "indicates if the json_deserializer function is called with bytes"
+ return exclusions.closed()
return cursor.server_side
elif self.engine.dialect.driver == "pg8000":
return getattr(cursor, "server_side", False)
+ elif self.engine.dialect.driver == "psycopg":
+ return bool(getattr(cursor, "name", False))
else:
return False
eq_(row, (data_element,))
eq_(js.mock_calls, [mock.call(data_element)])
- eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+ if testing.requires.json_deserializer_binary.enabled:
+ eq_(
+ jd.mock_calls,
+ [mock.call(json.dumps(data_element).encode())],
+ )
+ else:
+ eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
@testing.combinations(
("parameters",),
asyncpg;python_version>="3"
postgresql_psycopg2binary = psycopg2-binary
postgresql_psycopg2cffi = psycopg2cffi
+postgresql_psycopg = psycopg>=3.0.2
pymysql =
pymysql;python_version>="3"
pymysql<1;python_version<"3"
aiosqlite_file = sqlite+aiosqlite:///async_querytest.db
pysqlcipher_file = sqlite+pysqlcipher://:test@/querytest.db.enc
postgresql = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
+psycopg2 = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
+psycopg = postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
+psycopg_async = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test
+psycopg_async_fallback = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true
asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
from sqlalchemy import bindparam
from sqlalchemy import cast
from sqlalchemy import Column
+from sqlalchemy import create_engine
from sqlalchemy import DateTime
from sqlalchemy import DDL
from sqlalchemy import event
from sqlalchemy import TypeDecorator
from sqlalchemy import util
from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import psycopg as psycopg_dialect
from sqlalchemy.dialects.postgresql import psycopg2 as psycopg2_dialect
from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_BATCH
from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_PLAIN
if testing.against("postgresql+pg8000"):
# TODO: is there another way we're supposed to see this?
eq_(errmsg.orig.args[0]["C"], "23505")
- else:
+ elif not testing.against("postgresql+psycopg"):
eq_(errmsg.orig.pgcode, "23505")
- if testing.against("postgresql+asyncpg"):
+ if testing.against("postgresql+asyncpg") or testing.against(
+ "postgresql+psycopg"
+ ):
eq_(errmsg.orig.sqlstate, "23505")
".".join(str(x) for x in v)
)
+ @testing.only_on("postgresql+psycopg")
+ def test_psycopg_version(self):
+ v = testing.db.dialect.psycopg_version
+ assert testing.db.dialect.dbapi.__version__.startswith(
+ ".".join(str(x) for x in v)
+ )
+
@testing.combinations(
((8, 1), False, False),
((8, 1), None, False),
with testing.db.connect().execution_options(
isolation_level="SERIALIZABLE"
) as conn:
+
dbapi_conn = conn.connection.dbapi_connection
is_false(dbapi_conn.autocommit)
dbapi_conn.rollback()
eq_(val, "off")
- @testing.requires.psycopg2_compatibility
- def test_psycopg2_non_standard_err(self):
+ @testing.requires.psycopg_compatibility
+ def test_psycopg_non_standard_err(self):
# note that psycopg2 is sometimes called psycopg2cffi
# depending on platform
- psycopg2 = testing.db.dialect.dbapi
- TransactionRollbackError = __import__(
- "%s.extensions" % psycopg2.__name__
- ).extensions.TransactionRollbackError
+ psycopg = testing.db.dialect.dbapi
+ if psycopg.__version__.startswith("3"):
+ TransactionRollbackError = __import__(
+ "%s.errors" % psycopg.__name__
+ ).errors.TransactionRollback
+ else:
+ TransactionRollbackError = __import__(
+ "%s.extensions" % psycopg.__name__
+ ).extensions.TransactionRollbackError
exception = exc.DBAPIError.instance(
"some statement",
{},
TransactionRollbackError("foo"),
- psycopg2.Error,
+ psycopg.Error,
)
assert isinstance(exception, exc.OperationalError)
@testing.requires.no_coverage
- @testing.requires.psycopg2_compatibility
+ @testing.requires.psycopg_compatibility
def test_notice_logging(self):
log = logging.getLogger("sqlalchemy.dialects.postgresql")
buf = logging.handlers.BufferingHandler(100)
finally:
log.removeHandler(buf)
log.setLevel(lev)
- msgs = " ".join(b.msg for b in buf.buffer)
+ msgs = " ".join(b.getMessage() for b in buf.buffer)
eq_regex(
msgs,
- "NOTICE: notice: hi there(\nCONTEXT: .*?)? "
- "NOTICE: notice: another note(\nCONTEXT: .*?)?",
+ "NOTICE: [ ]?notice: hi there(\nCONTEXT: .*?)? "
+ "NOTICE: [ ]?notice: another note(\nCONTEXT: .*?)?",
)
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
@engines.close_open_connections
def test_client_encoding(self):
c = testing.db.connect()
new_encoding = c.exec_driver_sql("show client_encoding").fetchone()[0]
eq_(new_encoding, test_encoding)
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
@engines.close_open_connections
def test_autocommit_isolation_level(self):
c = testing.db.connect().execution_options(
assert result == [(1, "user", "lala")]
connection.execute(text("DROP TABLE speedy_users"))
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
def test_numeric_raise(self, connection):
stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
assert_raises(exc.InvalidRequestError, connection.execute, stmt)
)
@testing.requires.psycopg2_compatibility
- def test_initial_transaction_state(self):
+ def test_initial_transaction_state_psycopg2(self):
from psycopg2.extensions import STATUS_IN_TRANSACTION
engine = engines.testing_engine()
with engine.connect() as conn:
ne_(conn.connection.status, STATUS_IN_TRANSACTION)
+
+ @testing.only_on("postgresql+psycopg")
+ def test_initial_transaction_state_psycopg(self):
+ from psycopg.pq import TransactionStatus
+
+ engine = engines.testing_engine()
+ with engine.connect() as conn:
+ ne_(
+ conn.connection.dbapi_connection.info.transaction_status,
+ TransactionStatus.INTRANS,
+ )
+
+
+class Psycopg3Test(fixtures.TestBase):
+ __only_on__ = ("postgresql+psycopg",)
+
+ def test_json_correctly_registered(self, testing_engine):
+ import json
+
+ def loads(value):
+ value = json.loads(value)
+ value["x"] = value["x"] + "_loads"
+ return value
+
+ def dumps(value):
+ value = dict(value)
+ value["x"] = "dumps_y"
+ return json.dumps(value)
+
+ engine = testing_engine(
+ options=dict(json_serializer=dumps, json_deserializer=loads)
+ )
+ engine2 = testing_engine(
+ options=dict(
+ json_serializer=json.dumps, json_deserializer=json.loads
+ )
+ )
+
+ s = select(cast({"key": "value", "x": "q"}, JSONB))
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+ with engine2.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+
+ @testing.requires.hstore
+ def test_hstore_correctly_registered(self, testing_engine):
+ engine = testing_engine(options=dict(use_native_hstore=True))
+ engine2 = testing_engine(options=dict(use_native_hstore=False))
+
+ def rp(self, *a):
+ return lambda a: {"a": "b"}
+
+ with mock.patch.object(HSTORE, "result_processor", side_effect=rp):
+ s = select(cast({"key": "value", "x": "q"}, HSTORE))
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine2.begin() as conn:
+ eq_(conn.scalar(s), {"a": "b"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+
+ def test_get_dialect(self):
+ u = url.URL.create("postgresql://")
+ d = psycopg_dialect.PGDialect_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialect_psycopg)
+ d = psycopg_dialect.PGDialect_psycopg.get_async_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+ d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+ d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+
+ def test_async_version(self):
+ e = create_engine("postgresql+psycopg_async://")
+ is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg))
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import tuple_
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.sql.expression import type_coerce
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing.assertsql import DialectSQL
+class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults):
+ __only_on__ = "postgresql"
+ __backend__ = True
+
+ def test_count_star(self, connection):
+ eq_(connection.scalar(func.count("*")), 1)
+
+ def test_count_int(self, connection):
+ eq_(connection.scalar(func.count(1)), 1)
+
+
class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
],
)
+ def _strs_render_bind_casts(self, connection):
+
+ return (
+ connection.dialect._bind_typing_render_casts
+ and String().dialect_impl(connection.dialect).render_bind_cast
+ )
+
@testing.requires.pyformat_paramstyle
- def test_expression_pyformat(self):
+ def test_expression_pyformat(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- "matchtable.title @@ to_tsquery(%(title_1)s" ")",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s)",
+ )
@testing.requires.format_paramstyle
- def test_expression_positional(self):
+ def test_expression_positional(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- # note we assume current tested DBAPIs use emulated setinputsizes
- # here, the cast is not strictly necessary
- "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s)",
+ )
def test_simple_match(self, connection):
matchtable = self.tables.matchtable
[(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)],
)
- @testing.only_on(
- "postgresql+psycopg2",
- "I cannot get this to run at all on other drivers, "
- "even selecting from a table",
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails_on_everything_except(
+ ["postgresql+psycopg2"],
+ "I cannot get this to run at all on other drivers, "
+ "even selecting from a table",
+ ),
+ ),
+ argnames="cast_fn",
)
- def test_render_derived_quoting(self, connection):
+ def test_render_derived_quoting_text(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
fn = (
- func.json_to_recordset( # noqa
- '[{"CaseSensitive":1,"the % value":"foo"}, '
- '{"CaseSensitive":"2","the % value":"bar"}]'
+ func.json_to_recordset(value)
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
)
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails("Fails on all drivers"),
+ ),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_text_to_json(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+        # why won't this work?!?!?
+ # should be exactly json_to_recordset(to_json('string'::text))
+ #
+ fn = (
+ func.json_to_recordset(func.to_json(value))
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
+ )
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (type_coerce,),
+ (cast,),
+ (None, testing.fails("Fails on all PG backends")),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_straight_json(self, connection, cast_fn):
+ # these all work
+
+ value = [
+ {"CaseSensitive": 1, "the % value": "foo"},
+ {"CaseSensitive": "2", "the % value": "bar"},
+ ]
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+ fn = (
+ func.json_to_recordset(value) # noqa
.table_valued(
column("CaseSensitive", Integer), column("the % value", String)
)
testing.combinations(
sqltypes.ARRAY,
postgresql.ARRAY,
- (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+ (_ArrayOfEnum, testing.requires.psycopg_compatibility),
argnames="array_cls",
)(fn)
)
@testing.fixture
def non_native_hstore_connection(self, testing_engine):
- local_engine = testing.requires.psycopg2_native_hstore.enabled
+ local_engine = testing.requires.native_hstore.enabled
if local_engine:
engine = testing_engine(options=dict(use_native_hstore=False))
)["1"]
eq_(connection.scalar(select(expr)), "3")
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_insert_native(self, connection):
self._test_insert(connection)
def test_insert_python(self, non_native_hstore_connection):
self._test_insert(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_criterion_native(self, connection):
self._fixture_data(connection)
self._test_criterion(connection)
def test_fixed_round_trip_python(self, non_native_hstore_connection):
self._test_fixed_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_fixed_round_trip_native(self, connection):
self._test_fixed_round_trip(connection)
},
)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_unicode_round_trip_python(self, non_native_hstore_connection):
self._test_unicode_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_unicode_round_trip_native(self, connection):
self._test_unicode_round_trip(connection)
):
self._test_escaped_quotes_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_escaped_quotes_round_trip_native(self, connection):
self._test_escaped_quotes_round_trip(connection)
class _RangeTypeRoundTrip(fixtures.TablesTest):
- __requires__ = "range_types", "psycopg2_compatibility"
+ __requires__ = "range_types", "psycopg_compatibility"
__backend__ = True
def extras(self):
# older psycopg2 versions.
if testing.against("postgresql+psycopg2cffi"):
from psycopg2cffi import extras
- else:
+ elif testing.against("postgresql+psycopg2"):
from psycopg2 import extras
+ elif testing.against("postgresql+psycopg"):
+ from psycopg.types.range import Range
+
+ class psycopg_extras:
+ def __getattr__(self, _):
+ return Range
+
+ extras = psycopg_extras()
+ else:
+            assert False, "Unknown dialect"
return extras
@classmethod
)
+class TestGetDialect(fixtures.TestBase):
+ @testing.requires.sqlite
+ @testing.combinations(True, False, None)
+ def test_is_async_to_create_engine(self, is_async):
+ def get_dialect_cls(url):
+ url = url.set(drivername="sqlite")
+ return url.get_dialect()
+
+ global MockDialectGetDialect
+ MockDialectGetDialect = Mock()
+ MockDialectGetDialect.get_dialect_cls.side_effect = get_dialect_cls
+ MockDialectGetDialect.get_async_dialect_cls.side_effect = (
+ get_dialect_cls
+ )
+
+ registry.register("mockdialect", __name__, "MockDialectGetDialect")
+
+ from sqlalchemy.dialects import sqlite
+
+ kw = {}
+ if is_async is not None:
+ kw["_is_async"] = is_async
+ e = create_engine("mockdialect://", **kw)
+
+ eq_(e.dialect.name, "sqlite")
+ assert isinstance(e.dialect, sqlite.dialect)
+
+ if is_async:
+ eq_(
+ MockDialectGetDialect.mock_calls,
+ [
+ call.get_async_dialect_cls(url.make_url("mockdialect://")),
+ call.engine_created(e),
+ ],
+ )
+ else:
+ eq_(
+ MockDialectGetDialect.mock_calls,
+ [
+ call.get_dialect_cls(url.make_url("mockdialect://")),
+ call.engine_created(e),
+ ],
+ )
+ MockDialectGetDialect.reset_mock()
+ u = url.make_url("mockdialect://")
+ u.get_dialect(**kw)
+ if is_async:
+ eq_(
+ MockDialectGetDialect.mock_calls,
+ [call.get_async_dialect_cls(u)],
+ )
+ else:
+ eq_(
+ MockDialectGetDialect.mock_calls,
+ [call.get_dialect_cls(u)],
+ )
+
+
class MockDialect(DefaultDialect):
@classmethod
def dbapi(cls, **kw):
)
self.mock_connect = call(
- host="localhost", password="bar", user="foo", database="test"
+ host="localhost", password="bar", user="foo", dbname="test"
)
# monkeypatch disconnect checker
self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(
"+aiosqlite",
"+aiomysql",
"+asyncmy",
+ "+psycopg",
],
"Buffers the result set and doesn't check for connection close",
)
eq_(
canary.mock_calls,
- [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+ [
+ mock.call(
+ sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
+ )
+ ],
)
@async_test
eq_(
canary.mock_calls,
- [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+ [
+ mock.call(
+ sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
+ )
+ ],
)
@async_test
def pyformat_paramstyle(self):
return only_on(
[
+ "postgresql+psycopg",
"postgresql+psycopg2",
"postgresql+psycopg2cffi",
"mysql+mysqlconnector",
@property
def infinity_floats(self):
return fails_on_everything_except(
- "sqlite", "postgresql+psycopg2", "postgresql+asyncpg"
+ "sqlite",
+ "postgresql+psycopg2",
+ "postgresql+asyncpg",
+ "postgresql+psycopg",
) + skip_if(
"postgresql+pg8000", "seems to work on pg14 only, not earlier?"
)
@property
def range_types(self):
def check_range_types(config):
- if not against(
- config, ["postgresql+psycopg2", "postgresql+psycopg2cffi"]
- ):
+ if not self.psycopg_compatibility.enabled:
return False
try:
with config.db.connect() as conn:
)
@property
- def psycopg2_native_hstore(self):
- return self.psycopg2_compatibility
+ def native_hstore(self):
+ return self.psycopg_compatibility
@property
def psycopg2_compatibility(self):
return only_on(["postgresql+psycopg2", "postgresql+psycopg2cffi"])
@property
- def psycopg2_or_pg8000_compatibility(self):
+ def psycopg_compatibility(self):
return only_on(
[
"postgresql+psycopg2",
"postgresql+psycopg2cffi",
- "postgresql+pg8000",
+ "postgresql+psycopg",
]
)
+ @property
+ def psycopg_or_pg8000_compatibility(self):
+ return only_on([self.psycopg_compatibility, "postgresql+pg8000"])
+
@property
def percent_schema_names(self):
return skip_if(
def reflect_tables_no_columns(self):
# so far sqlite, mariadb, mysql don't support this
return only_on(["postgresql"])
+
+ @property
+ def json_deserializer_binary(self):
+ "indicates if the json_deserializer function is called with bytes"
+ return only_on(["postgresql+psycopg"])
@testing.combinations(
(True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit"
)
- @testing.provide_metadata
@testing.skip_if("mysql < 8")
- def test_duplicate_values_accepted(self, native, omit):
+ def test_duplicate_values_accepted(
+ self, metadata, connection, native, omit
+ ):
foo_enum = pep435_enum("foo_enum")
foo_enum("one", 1, "two")
foo_enum("three", 3, "four")
tbl = sa.Table(
"foo_table",
- self.metadata,
+ metadata,
sa.Column("id", sa.Integer),
sa.Column(
"data",
)
t = sa.table("foo_table", sa.column("id"), sa.column("data"))
- self.metadata.create_all(testing.db)
+ metadata.create_all(connection)
if omit:
with expect_raises(
(
exc.DBAPIError,
)
):
- with testing.db.begin() as conn:
- conn.execute(
- t.insert(),
- [
- {"id": 1, "data": "four"},
- {"id": 2, "data": "three"},
- ],
- )
- else:
- with testing.db.begin() as conn:
- conn.execute(
+ connection.execute(
t.insert(),
- [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
+ [
+ {"id": 1, "data": "four"},
+ {"id": 2, "data": "three"},
+ ],
)
+ else:
+ connection.execute(
+ t.insert(),
+ [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
+ )
- eq_(
- conn.execute(t.select().order_by(t.c.id)).fetchall(),
- [(1, "four"), (2, "three")],
- )
- eq_(
- conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
- [(1, foo_enum.three), (2, foo_enum.three)],
- )
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "four"), (2, "three")],
+ )
+ eq_(
+ connection.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
+ [(1, foo_enum.three), (2, foo_enum.three)],
+ )
MyPickleType = None
postgresql: .[postgresql]
postgresql: .[postgresql_asyncpg]; python_version >= '3'
postgresql: .[postgresql_pg8000]; python_version >= '3'
+ postgresql: .[postgresql_psycopg]; python_version >= '3'
mysql: .[mysql]
mysql: .[pymysql]
dbapimain-postgresql: git+https://github.com/psycopg/psycopg2.git#egg=psycopg2
dbapimain-postgresql: git+https://github.com/MagicStack/asyncpg.git#egg=asyncpg
dbapimain-postgresql: git+https://github.com/tlocke/pg8000.git#egg=pg8000
+ dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg&subdirectory=psycopg
+ dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg-c&subdirectory=psycopg_c
dbapimain-mysql: git+https://github.com/PyMySQL/mysqlclient-python.git#egg=mysqlclient
dbapimain-mysql: git+https://github.com/PyMySQL/PyMySQL.git#egg=pymysql
sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
- py3{,5,6,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
- py3{,5,6,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}
+ py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
+ py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}
# omit pysqlcipher for Python 3.10
py3{,10,11}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
- py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
- py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
+ py3{,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000 --dbdriver psycopg --dbdriver psycopg_async}
mysql: MYSQL={env:TOX_MYSQL:--db mysql}
py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}}