Added support for the ``psycopg`` dialect.
author     Federico Caselli <cfederico87@gmail.com>
           Tue, 14 Sep 2021 21:38:00 +0000 (23:38 +0200)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Fri, 26 Nov 2021 15:14:44 +0000 (10:14 -0500)
Both sync and async versions are supported.

Fixes: #6842
Change-Id: I57751c5028acebfc6f9c43572562405453a2f2a4

32 files changed:
doc/build/changelog/migration_20.rst
doc/build/changelog/unreleased_14/6842.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/_psycopg_common.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_types.py
setup.cfg
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_types.py
test/engine/test_parseconnect.py
test/engine/test_reconnect.py
test/ext/asyncio/test_engine_py3k.py
test/requirements.py
test/sql/test_types.py
tox.ini

index b75cefb31ee976f538d6caf1ded3495bc08a1a16..1c01698886154e8834e4f21149a984ef9f685926 100644 (file)
@@ -77,6 +77,31 @@ New Features and Improvements
 This section covers new features and improvements in SQLAlchemy 2.0 which
 are not otherwise part of the major 1.4->2.0 migration path.
 
+.. _ticket_6842:
+
+Dialect support for psycopg 3 (a.k.a. "psycopg")
+-------------------------------------------------
+
+Added dialect support for the `psycopg 3 <https://pypi.org/project/psycopg/>`_
+DBAPI, which despite the number "3" now goes by the package name ``psycopg``,
+superseding the previous ``psycopg2`` package that for the time being remains
+SQLAlchemy's "default" driver for the ``postgresql`` dialects. ``psycopg`` is a
+completely reworked and modernized database adapter for PostgreSQL which
+supports concepts such as prepared statements as well as Python asyncio.
+
+``psycopg`` is the first DBAPI supported by SQLAlchemy which provides
+both a pep-249 synchronous API and an asyncio driver.  The same
+``psycopg`` database URL may be used with the :func:`_sa.create_engine`
+and :func:`_asyncio.create_async_engine` engine-creation functions, and the
+corresponding sync or asyncio version of the dialect will be selected
+automatically.
+
+.. seealso::
+
+    :ref:`postgresql_psycopg`
+
+
+
 Behavioral Changes
 ==================
 
diff --git a/doc/build/changelog/unreleased_14/6842.rst b/doc/build/changelog/unreleased_14/6842.rst
new file mode 100644 (file)
index 0000000..43b3841
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: postgresql, dialect
+    :tickets: 6842
+
+    Added support for the ``psycopg`` dialect, supporting both sync and
+    async execution. This dialect is available under the
+    ``postgresql+psycopg`` name for both the :func:`_sa.create_engine` and
+    :func:`_asyncio.create_async_engine` engine-creation functions.
+
+    .. seealso::
+
+        :ref:`ticket_6842`
+
+        :ref:`postgresql_psycopg`
\ No newline at end of file
index 958f8e06026a1c9df5c49c7ba3fdd2fa860b3a9b..d3c9928c71b1abf54667d30fa2bcb93aab86cc7c 100644 (file)
@@ -193,6 +193,13 @@ psycopg2
 
 .. automodule:: sqlalchemy.dialects.postgresql.psycopg2
 
+.. _postgresql_psycopg:
+
+psycopg
+--------
+
+.. automodule:: sqlalchemy.dialects.postgresql.psycopg
+
 pg8000
 ------
 
index b595714603619ac922b143e9a4b82eca753c0a97..8092e99ffbc9b44fa79688d06ece232a32cc5b08 100644 (file)
@@ -175,7 +175,7 @@ class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
 
 class AsyncAdapt_asyncmy_connection(AdaptedConnection):
     await_ = staticmethod(await_only)
-    __slots__ = ("dbapi", "_connection", "_execute_mutex")
+    __slots__ = ("dbapi", "_execute_mutex")
 
     def __init__(self, dbapi, connection):
         self.dbapi = dbapi
index 08b05dc748f487b8afe5d25af100d8103c12879b..b0af5395e7efaacfc79111ebc6706eaa8cbe8a6f 100644 (file)
@@ -4,9 +4,12 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+from types import ModuleType
+
 from . import asyncpg  # noqa
 from . import base
 from . import pg8000  # noqa
+from . import psycopg  # noqa
 from . import psycopg2  # noqa
 from . import psycopg2cffi  # noqa
 from .array import All
@@ -58,6 +61,11 @@ from .ranges import TSRANGE
 from .ranges import TSTZRANGE
 from ...util import compat
 
+# Also alias psycopg as psycopg_async
+psycopg_async = type(
+    "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
+)
+
 base.dialect = dialect = psycopg2.dialect
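The alias above is what makes the explicit ``postgresql+psycopg_async://``
URL work: SQLAlchemy's default dialect loader imports the ``postgresql``
package, looks up the driver name as an attribute on it, and reads that
object's ``dialect`` member. A minimal sketch of that resolution (assuming
a SQLAlchemy checkout containing this commit):

    from sqlalchemy.dialects import postgresql

    # the synthesized module-like object defined above; its "dialect"
    # attribute points at the async variant of the psycopg dialect
    dialect_cls = postgresql.psycopg_async.dialect
    print(dialect_cls)  # <class '...psycopg.PGDialectAsync_psycopg'>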
 
 
diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
new file mode 100644 (file)
index 0000000..d82d5f0
--- /dev/null
@@ -0,0 +1,188 @@
+import decimal
+
+from .array import ARRAY as PGARRAY
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import UUID
+from .hstore import HSTORE
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+_server_side_id = util.counter()
+
+
+class _PsycopgNumeric(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect, coltype):
+        if self.asdecimal:
+            if coltype in _FLOAT_TYPES:
+                return processors.to_decimal_processor_factory(
+                    decimal.Decimal, self._effective_decimal_return_scale
+                )
+            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+                # psycopg returns Decimal natively for 1700
+                return None
+            else:
+                raise exc.InvalidRequestError(
+                    "Unknown PG numeric type: %d" % coltype
+                )
+        else:
+            if coltype in _FLOAT_TYPES:
+                # psycopg returns float natively for 701
+                return None
+            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+                return processors.to_float
+            else:
+                raise exc.InvalidRequestError(
+                    "Unknown PG numeric type: %d" % coltype
+                )
+
+
+class _PsycopgHStore(HSTORE):
+    def bind_processor(self, dialect):
+        if dialect._has_native_hstore:
+            return None
+        else:
+            return super(_PsycopgHStore, self).bind_processor(dialect)
+
+    def result_processor(self, dialect, coltype):
+        if dialect._has_native_hstore:
+            return None
+        else:
+            return super(_PsycopgHStore, self).result_processor(
+                dialect, coltype
+            )
+
+
+class _PsycopgUUID(UUID):
+    def bind_processor(self, dialect):
+        return None
+
+    def result_processor(self, dialect, coltype):
+        if not self.as_uuid and dialect.use_native_uuid:
+
+            def process(value):
+                if value is not None:
+                    value = str(value)
+                return value
+
+            return process
+
+
+class _PsycopgARRAY(PGARRAY):
+    render_bind_cast = True
+
+
+class _PGExecutionContext_common_psycopg(PGExecutionContext):
+    def create_server_side_cursor(self):
+        # use server-side cursors:
+        # psycopg
+        # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors
+        # psycopg2
+        # https://www.psycopg.org/docs/usage.html#server-side-cursors
+        ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+        return self._dbapi_connection.cursor(ident)
+
+
+class _PGDialect_common_psycopg(PGDialect):
+    supports_statement_cache = True
+    supports_server_side_cursors = True
+
+    default_paramstyle = "pyformat"
+
+    _has_native_hstore = True
+
+    colspecs = util.update_copy(
+        PGDialect.colspecs,
+        {
+            sqltypes.Numeric: _PsycopgNumeric,
+            HSTORE: _PsycopgHStore,
+            UUID: _PsycopgUUID,
+            sqltypes.ARRAY: _PsycopgARRAY,
+        },
+    )
+
+    def __init__(
+        self,
+        client_encoding=None,
+        use_native_hstore=True,
+        use_native_uuid=True,
+        **kwargs
+    ):
+        PGDialect.__init__(self, **kwargs)
+        if not use_native_hstore:
+            self._has_native_hstore = False
+        self.use_native_hstore = use_native_hstore
+        self.use_native_uuid = use_native_uuid
+        self.client_encoding = client_encoding
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username="user", database="dbname")
+
+        is_multihost = False
+        if "host" in url.query:
+            is_multihost = isinstance(url.query["host"], (list, tuple))
+
+        if opts:
+            if "port" in opts:
+                opts["port"] = int(opts["port"])
+            opts.update(url.query)
+            if is_multihost:
+                opts["host"] = ",".join(url.query["host"])
+            # send individual dbname, user, password, host, port
+            # parameters to the psycopg / psycopg2 connect() function
+            return ([], opts)
+        elif url.query:
+            # any other connection arguments, pass directly
+            opts.update(url.query)
+            if is_multihost:
+                opts["host"] = ",".join(url.query["host"])
+            return ([], opts)
+        else:
+            # no connection arguments whatsoever; psycopg(2).connect()
+            # requires that "dsn" be present as a blank string.
+            return ([""], opts)
+
+    def get_isolation_level_values(self, dbapi_conn):
+        return (
+            "AUTOCOMMIT",
+            "READ COMMITTED",
+            "READ UNCOMMITTED",
+            "REPEATABLE READ",
+            "SERIALIZABLE",
+        )
+
+    def set_deferrable(self, connection, value):
+        connection.deferrable = value
+
+    def get_deferrable(self, connection):
+        return connection.deferrable
+
+    def _do_autocommit(self, connection, value):
+        connection.autocommit = value
+
+    def do_ping(self, dbapi_connection):
+        cursor = None
+        try:
+            self._do_autocommit(dbapi_connection, True)
+            cursor = dbapi_connection.cursor()
+            try:
+                cursor.execute(self._dialect_specific_select_one)
+            finally:
+                cursor.close()
+                if not dbapi_connection.closed:
+                    self._do_autocommit(dbapi_connection, False)
+        except self.dbapi.Error as err:
+            if self.is_disconnect(err, dbapi_connection, cursor):
+                return False
+            else:
+                raise
+        else:
+            return True
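A brief sketch of what ``create_connect_args()`` above produces for a
typical URL (illustrative only; the ``psycopg`` subclass additionally
injects a ``context`` adapters-map entry, omitted here):

    from sqlalchemy.engine import make_url
    from sqlalchemy.dialects.postgresql._psycopg_common import (
        _PGDialect_common_psycopg,
    )

    url = make_url(
        "postgresql+psycopg://scott:tiger@localhost:5432/test?sslmode=require"
    )
    cargs, cparams = _PGDialect_common_psycopg().create_connect_args(url)
    # cargs   -> []
    # cparams -> {'user': 'scott', 'password': 'tiger', 'host': 'localhost',
    #             'port': 5432, 'dbname': 'test', 'sslmode': 'require'}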
index 5ef9df800147889ef35bb9cfd0f29130b3b44ea6..1fdb46b6f2e6ca595c7b185d7bb32b3da29cefa1 100644 (file)
@@ -553,7 +553,6 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
 class AsyncAdapt_asyncpg_connection(AdaptedConnection):
     __slots__ = (
         "dbapi",
-        "_connection",
         "isolation_level",
         "_isolation_setting",
         "readonly",
index 008398865324cdb8850f2a56e978d138864d4c9d..d1d881dc3d48fb01c2196faebe2eb4abfb9cd833 100644 (file)
@@ -1701,7 +1701,7 @@ class UUID(sqltypes.TypeEngine):
     or as Python uuid objects.
 
     The UUID type is currently known to work within the prominent DBAPI
-    drivers supported by SQLAlchemy including psycopg2, pg8000 and
+    drivers supported by SQLAlchemy including psycopg, psycopg2, pg8000 and
     asyncpg. Support for other DBAPI drivers may be incomplete or non-present.
 
     """
@@ -1992,6 +1992,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
 
             self.connection.execute(DropEnumType(enum))
 
+    def get_dbapi_type(self, dbapi):
+        """don't return dbapi.STRING for ENUM in PostgreSQL, since that's
+        a different type"""
+
+        return None
+
     def _check_for_name_in_memos(self, checkfirst, kw):
         """Look in the 'ddl runner' for 'memos', then
         note our name in that collection.
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py
new file mode 100644 (file)
index 0000000..c2017c9
--- /dev/null
@@ -0,0 +1,641 @@
+# postgresql/psycopg.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg
+    :name: psycopg (a.k.a. psycopg 3)
+    :dbapi: psycopg
+    :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...]
+    :url: https://pypi.org/project/psycopg/
+
+``psycopg`` is the package and module name for version 3 of the ``psycopg``
+database driver, formerly known as ``psycopg2``.  This driver is different
+enough from its ``psycopg2`` predecessor that SQLAlchemy supports it
+via a totally separate dialect; support for ``psycopg2`` is expected to remain
+for as long as that package continues to function for modern Python versions,
+and also remains the default dialect for the ``postgresql://`` dialect
+series.
+
+The SQLAlchemy ``psycopg`` dialect provides both a sync and an async
+implementation under the same dialect name. The proper version is
+selected depending on how the engine is created:
+
+* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will
+  automatically select the sync version, e.g.::
+
+    from sqlalchemy import create_engine
+    sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
+
+* calling :func:`_asyncio.create_async_engine` with
+  ``postgresql+psycopg://...`` will automatically select the async version,
+  e.g.::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")
+
+The asyncio version of the dialect may also be specified explicitly using the
+``psycopg_async`` suffix, e.g.::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")
+
+The ``psycopg`` dialect has the same API features as that of ``psycopg2``,
+with the exception of the "fast executemany" helpers, which are expected to
+be generalized and ported to ``psycopg`` before the final release of
+SQLAlchemy 2.0.
+
+
+.. seealso::
+
+    :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg``
+    dialect shares most of its behavior with the ``psycopg2`` dialect.
+    Further documentation is available there.
+
+"""  # noqa
+import logging
+import re
+
+from ._psycopg_common import _PGDialect_common_psycopg
+from ._psycopg_common import _PGExecutionContext_common_psycopg
+from ._psycopg_common import _PsycopgUUID
+from .base import INTERVAL
+from .base import PGCompiler
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .json import JSON
+from .json import JSONB
+from .json import JSONPathType
+from ... import pool
+from ... import types as sqltypes
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+logger = logging.getLogger("sqlalchemy.dialects.postgresql")
+
+
+class _PGString(sqltypes.String):
+    render_bind_cast = True
+
+
+class _PGJSON(JSON):
+    render_bind_cast = True
+
+    def bind_processor(self, dialect):
+        return self._make_bind_processor(None, dialect._psycopg_Json)
+
+    def result_processor(self, dialect, coltype):
+        return None
+
+
+class _PGJSONB(JSONB):
+    render_bind_cast = True
+
+    def bind_processor(self, dialect):
+        return self._make_bind_processor(None, dialect._psycopg_Jsonb)
+
+    def result_processor(self, dialect, coltype):
+        return None
+
+
+class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+    __visit_name__ = "json_int_index"
+
+    render_bind_cast = True
+
+
+class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+    __visit_name__ = "json_str_index"
+
+    render_bind_cast = True
+
+
+class _PGJSONPathType(JSONPathType):
+    pass
+
+
+class _PGUUID(_PsycopgUUID):
+    render_bind_cast = True
+
+
+class _PGInterval(INTERVAL):
+    render_bind_cast = True
+
+
+class _PGTimeStamp(sqltypes.DateTime):
+    render_bind_cast = True
+
+
+class _PGDate(sqltypes.Date):
+    render_bind_cast = True
+
+
+class _PGTime(sqltypes.Time):
+    render_bind_cast = True
+
+
+class _PGInteger(sqltypes.Integer):
+    render_bind_cast = True
+
+
+class _PGSmallInteger(sqltypes.SmallInteger):
+    render_bind_cast = True
+
+
+class _PGNullType(sqltypes.NullType):
+    render_bind_cast = True
+
+
+class _PGBigInteger(sqltypes.BigInteger):
+    render_bind_cast = True
+
+
+class _PGBoolean(sqltypes.Boolean):
+    render_bind_cast = True
+
+
+class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
+    pass
+
+
+class PGCompiler_psycopg(PGCompiler):
+    pass
+
+
+class PGIdentifierPreparer_psycopg(PGIdentifierPreparer):
+    pass
+
+
+def _log_notices(diagnostic):
+    logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary)
+
+
+class PGDialect_psycopg(_PGDialect_common_psycopg):
+    driver = "psycopg"
+
+    supports_statement_cache = True
+    supports_server_side_cursors = True
+    default_paramstyle = "pyformat"
+    supports_sane_multi_rowcount = True
+
+    execution_ctx_cls = PGExecutionContext_psycopg
+    statement_compiler = PGCompiler_psycopg
+    preparer = PGIdentifierPreparer_psycopg
+    psycopg_version = (0, 0)
+
+    _has_native_hstore = True
+
+    colspecs = util.update_copy(
+        _PGDialect_common_psycopg.colspecs,
+        {
+            sqltypes.String: _PGString,
+            JSON: _PGJSON,
+            sqltypes.JSON: _PGJSON,
+            JSONB: _PGJSONB,
+            sqltypes.JSON.JSONPathType: _PGJSONPathType,
+            sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
+            sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
+            UUID: _PGUUID,
+            sqltypes.Interval: _PGInterval,
+            INTERVAL: _PGInterval,
+            sqltypes.Date: _PGDate,
+            sqltypes.DateTime: _PGTimeStamp,
+            sqltypes.Time: _PGTime,
+            sqltypes.Integer: _PGInteger,
+            sqltypes.SmallInteger: _PGSmallInteger,
+            sqltypes.BigInteger: _PGBigInteger,
+        },
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        if self.dbapi:
+            m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+            if m:
+                self.psycopg_version = tuple(
+                    int(x) for x in m.group(1, 2, 3) if x is not None
+                )
+
+            if self.psycopg_version < (3, 0, 2):
+                raise ImportError(
+                    "psycopg version 3.0.2 or higher is required."
+                )
+
+            from psycopg.adapt import AdaptersMap
+
+            self._psycopg_adapters_map = adapters_map = AdaptersMap(
+                self.dbapi.adapters
+            )
+
+            if self._json_deserializer:
+                from psycopg.types.json import set_json_loads
+
+                set_json_loads(self._json_deserializer, adapters_map)
+
+            if self._json_serializer:
+                from psycopg.types.json import set_json_dumps
+
+                set_json_dumps(self._json_serializer, adapters_map)
+
+    def create_connect_args(self, url):
+        # see https://github.com/psycopg/psycopg/issues/83
+        cargs, cparams = super().create_connect_args(url)
+
+        cparams["context"] = self._psycopg_adapters_map
+        if self.client_encoding is not None:
+            cparams["client_encoding"] = self.client_encoding
+        return cargs, cparams
+
+    def _type_info_fetch(self, connection, name):
+        from psycopg.types import TypeInfo
+
+        return TypeInfo.fetch(connection.connection, name)
+
+    def initialize(self, connection):
+        super().initialize(connection)
+
+        # PGDialect.initialize() checks server version for <= 8.2 and sets
+        # this flag to False if so
+        if not self.full_returning:
+            self.insert_executemany_returning = False
+
+        # HSTORE can't be registered until we have a connection so that
+        # we can look up its OID, so we set up this adapter in
+        # initialize()
+        if self.use_native_hstore:
+            info = self._type_info_fetch(connection, "hstore")
+            self._has_native_hstore = info is not None
+            if self._has_native_hstore:
+                from psycopg.types.hstore import register_hstore
+
+                # register the adapter for connections made subsequent to
+                # this one
+                register_hstore(info, self._psycopg_adapters_map)
+
+                # register the adapter for this connection
+                register_hstore(info, connection.connection)
+
+    @classmethod
+    def dbapi(cls):
+        import psycopg
+
+        return psycopg
+
+    @classmethod
+    def get_async_dialect_cls(cls, url):
+        return PGDialectAsync_psycopg
+
+    @util.memoized_property
+    def _isolation_lookup(self):
+        return {
+            "READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED,
+            "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED,
+            "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ,
+            "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE,
+        }
+
+    @util.memoized_property
+    def _psycopg_Json(self):
+        from psycopg.types import json
+
+        return json.Json
+
+    @util.memoized_property
+    def _psycopg_Jsonb(self):
+        from psycopg.types import json
+
+        return json.Jsonb
+
+    @util.memoized_property
+    def _psycopg_TransactionStatus(self):
+        from psycopg.pq import TransactionStatus
+
+        return TransactionStatus
+
+    def _do_isolation_level(self, connection, autocommit, isolation_level):
+        connection.autocommit = autocommit
+        connection.isolation_level = isolation_level
+
+    def get_isolation_level(self, dbapi_connection):
+        if hasattr(dbapi_connection, "dbapi_connection"):
+            dbapi_connection = dbapi_connection.dbapi_connection
+
+        status_before = dbapi_connection.info.transaction_status
+        value = super().get_isolation_level(dbapi_connection)
+
+        # don't rely on psycopg providing enum symbols, compare with
+        # eq/ne
+        if status_before == self._psycopg_TransactionStatus.IDLE:
+            dbapi_connection.rollback()
+        return value
+
+    def set_isolation_level(self, connection, level):
+        connection = getattr(connection, "dbapi_connection", connection)
+        if level == "AUTOCOMMIT":
+            self._do_isolation_level(
+                connection, autocommit=True, isolation_level=None
+            )
+        else:
+            self._do_isolation_level(
+                connection,
+                autocommit=False,
+                isolation_level=self._isolation_lookup[level],
+            )
+
+    def set_readonly(self, connection, value):
+        connection.read_only = value
+
+    def get_readonly(self, connection):
+        return connection.read_only
+
+    def on_connect(self):
+        def notices(conn):
+            conn.add_notice_handler(_log_notices)
+
+        fns = [notices]
+
+        if self.isolation_level is not None:
+
+            def on_connect(conn):
+                self.set_isolation_level(conn, self.isolation_level)
+
+            fns.append(on_connect)
+
+        # fns always has the notices function
+        def on_connect(conn):
+            for fn in fns:
+                fn(conn)
+
+        return on_connect
+
+    def is_disconnect(self, e, connection, cursor):
+        if isinstance(e, self.dbapi.Error) and connection is not None:
+            if connection.closed or connection.broken:
+                return True
+        return False
+
+    def _do_prepared_twophase(self, connection, command, recover=False):
+        dbapi_conn = connection.connection.dbapi_connection
+        if (
+            recover
+            # don't rely on psycopg providing enum symbols, compare with
+            # eq/ne
+            or dbapi_conn.info.transaction_status
+            != self._psycopg_TransactionStatus.IDLE
+        ):
+            dbapi_conn.rollback()
+        before = dbapi_conn.autocommit
+        try:
+            self._do_autocommit(dbapi_conn, True)
+            dbapi_conn.execute(command)
+        finally:
+            self._do_autocommit(dbapi_conn, before)
+
+    def do_rollback_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        if is_prepared:
+            self._do_prepared_twophase(
+                connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
+            )
+        else:
+            self.do_rollback(connection.connection)
+
+    def do_commit_twophase(
+        self, connection, xid, is_prepared=True, recover=False
+    ):
+        if is_prepared:
+            self._do_prepared_twophase(
+                connection, f"COMMIT PREPARED '{xid}'", recover=recover
+            )
+        else:
+            self.do_commit(connection.connection)
+
+
+class AsyncAdapt_psycopg_cursor:
+    __slots__ = ("_cursor", "await_", "_rows")
+
+    _psycopg_ExecStatus = None
+
+    def __init__(self, cursor, await_) -> None:
+        self._cursor = cursor
+        self.await_ = await_
+        self._rows = []
+
+    def __getattr__(self, name):
+        return getattr(self._cursor, name)
+
+    @property
+    def arraysize(self):
+        return self._cursor.arraysize
+
+    @arraysize.setter
+    def arraysize(self, value):
+        self._cursor.arraysize = value
+
+    def close(self):
+        self._rows.clear()
+        # a normal cursor's _close() is synchronous; call it directly
+        self._cursor._close()
+
+    def execute(self, query, params=None, **kw):
+        result = self.await_(self._cursor.execute(query, params, **kw))
+        # the SQLAlchemy result is not async, so pull all rows here
+        res = self._cursor.pgresult
+
+        # don't rely on psycopg providing enum symbols, compare with
+        # eq/ne
+        if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
+            rows = self.await_(self._cursor.fetchall())
+            if not isinstance(rows, list):
+                self._rows = list(rows)
+            else:
+                self._rows = rows
+        return result
+
+    def executemany(self, query, params_seq):
+        return self.await_(self._cursor.executemany(query, params_seq))
+
+    def __iter__(self):
+        # TODO: try to avoid pop(0) on a list
+        while self._rows:
+            yield self._rows.pop(0)
+
+    def fetchone(self):
+        if self._rows:
+            # TODO: try to avoid pop(0) on a list
+            return self._rows.pop(0)
+        else:
+            return None
+
+    def fetchmany(self, size=None):
+        if size is None:
+            size = self._cursor.arraysize
+
+        retval = self._rows[0:size]
+        self._rows = self._rows[size:]
+        return retval
+
+    def fetchall(self):
+        retval = self._rows
+        self._rows = []
+        return retval
+
+
+class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
+    def execute(self, query, params=None, **kw):
+        self.await_(self._cursor.execute(query, params, **kw))
+        return self
+
+    def close(self):
+        self.await_(self._cursor.close())
+
+    def fetchone(self):
+        return self.await_(self._cursor.fetchone())
+
+    def fetchmany(self, size=0):
+        return self.await_(self._cursor.fetchmany(size))
+
+    def fetchall(self):
+        return self.await_(self._cursor.fetchall())
+
+    def __iter__(self):
+        iterator = self._cursor.__aiter__()
+        while True:
+            try:
+                yield self.await_(iterator.__anext__())
+            except StopAsyncIteration:
+                break
+
+
+class AsyncAdapt_psycopg_connection(AdaptedConnection):
+    __slots__ = ()
+    await_ = staticmethod(await_only)
+
+    def __init__(self, connection) -> None:
+        self._connection = connection
+
+    def __getattr__(self, name):
+        return getattr(self._connection, name)
+
+    def execute(self, query, params=None, **kw):
+        cursor = self.await_(self._connection.execute(query, params, **kw))
+        return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+
+    def cursor(self, *args, **kw):
+        cursor = self._connection.cursor(*args, **kw)
+        if hasattr(cursor, "name"):
+            return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_)
+        else:
+            return AsyncAdapt_psycopg_cursor(cursor, self.await_)
+
+    def commit(self):
+        self.await_(self._connection.commit())
+
+    def rollback(self):
+        self.await_(self._connection.rollback())
+
+    def close(self):
+        self.await_(self._connection.close())
+
+    @property
+    def autocommit(self):
+        return self._connection.autocommit
+
+    @autocommit.setter
+    def autocommit(self, value):
+        self.set_autocommit(value)
+
+    def set_autocommit(self, value):
+        self.await_(self._connection.set_autocommit(value))
+
+    def set_isolation_level(self, value):
+        self.await_(self._connection.set_isolation_level(value))
+
+    def set_read_only(self, value):
+        self.await_(self._connection.set_read_only(value))
+
+    def set_deferrable(self, value):
+        self.await_(self._connection.set_deferrable(value))
+
+
+class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
+    __slots__ = ()
+    await_ = staticmethod(await_fallback)
+
+
+class PsycopgAdaptDBAPI:
+    def __init__(self, psycopg) -> None:
+        self.psycopg = psycopg
+
+        for k, v in self.psycopg.__dict__.items():
+            if k != "connect":
+                self.__dict__[k] = v
+
+    def connect(self, *arg, **kw):
+        async_fallback = kw.pop("async_fallback", False)
+        if util.asbool(async_fallback):
+            return AsyncAdaptFallback_psycopg_connection(
+                await_fallback(
+                    self.psycopg.AsyncConnection.connect(*arg, **kw)
+                )
+            )
+        else:
+            return AsyncAdapt_psycopg_connection(
+                await_only(self.psycopg.AsyncConnection.connect(*arg, **kw))
+            )
+
+
+class PGDialectAsync_psycopg(PGDialect_psycopg):
+    is_async = True
+    supports_statement_cache = True
+
+    @classmethod
+    def dbapi(cls):
+        import psycopg
+        from psycopg.pq import ExecStatus
+
+        AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus
+
+        return PsycopgAdaptDBAPI(psycopg)
+
+    @classmethod
+    def get_pool_class(cls, url):
+
+        async_fallback = url.query.get("async_fallback", False)
+
+        if util.asbool(async_fallback):
+            return pool.FallbackAsyncAdaptedQueuePool
+        else:
+            return pool.AsyncAdaptedQueuePool
+
+    def _type_info_fetch(self, connection, name):
+        from psycopg.types import TypeInfo
+
+        adapted = connection.connection
+        return adapted.await_(TypeInfo.fetch(adapted._connection, name))
+
+    def _do_isolation_level(self, connection, autocommit, isolation_level):
+        connection.set_autocommit(autocommit)
+        connection.set_isolation_level(isolation_level)
+
+    def _do_autocommit(self, connection, value):
+        connection.set_autocommit(value)
+
+    def set_readonly(self, connection, value):
+        connection.set_read_only(value)
+
+    def set_deferrable(self, connection, value):
+        connection.set_deferrable(value)
+
+    def get_driver_connection(self, connection):
+        return connection._connection
+
+
+dialect = PGDialect_psycopg
+dialect_async = PGDialectAsync_psycopg
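The isolation-level plumbing above is reached through SQLAlchemy's standard
``execution_options()`` API; a short usage sketch (standard public API,
shown here only for orientation):

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")

    # "AUTOCOMMIT" takes the autocommit=True path in set_isolation_level();
    # the remaining values map through _isolation_lookup to psycopg's
    # IsolationLevel enum.
    with engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    ) as conn:
        conn.execute(text("VACUUM"))  # VACUUM must run outside a transaction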
index 39b6e0ed1e32ae35c8906efc72c57ae7c3638255..3d9f90a29795985d9a9582e496f9a503475d7cbe 100644 (file)
@@ -11,6 +11,8 @@ r"""
     :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
     :url: https://pypi.org/project/psycopg2/
 
+.. _psycopg2_toplevel:
+
 psycopg2 Connect Arguments
 --------------------------
 
@@ -442,25 +444,15 @@ which may be more performant.
 
 """  # noqa
 import collections.abc as collections_abc
-import decimal
 import logging
 import re
-from uuid import UUID as _python_UUID
 
-from .array import ARRAY as PGARRAY
-from .base import _DECIMAL_TYPES
-from .base import _FLOAT_TYPES
-from .base import _INT_TYPES
+from ._psycopg_common import _PGDialect_common_psycopg
+from ._psycopg_common import _PGExecutionContext_common_psycopg
 from .base import PGCompiler
-from .base import PGDialect
-from .base import PGExecutionContext
 from .base import PGIdentifierPreparer
-from .base import UUID
-from .hstore import HSTORE
 from .json import JSON
 from .json import JSONB
-from ... import exc
-from ... import processors
 from ... import types as sqltypes
 from ... import util
 from ...engine import cursor as _cursor
@@ -469,53 +461,6 @@ from ...engine import cursor as _cursor
 logger = logging.getLogger("sqlalchemy.dialects.postgresql")
 
 
-class _PGNumeric(sqltypes.Numeric):
-    def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            if coltype in _FLOAT_TYPES:
-                return processors.to_decimal_processor_factory(
-                    decimal.Decimal, self._effective_decimal_return_scale
-                )
-            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
-                # pg8000 returns Decimal natively for 1700
-                return None
-            else:
-                raise exc.InvalidRequestError(
-                    "Unknown PG numeric type: %d" % coltype
-                )
-        else:
-            if coltype in _FLOAT_TYPES:
-                # pg8000 returns float natively for 701
-                return None
-            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
-                return processors.to_float
-            else:
-                raise exc.InvalidRequestError(
-                    "Unknown PG numeric type: %d" % coltype
-                )
-
-
-class _PGHStore(HSTORE):
-    def bind_processor(self, dialect):
-        if dialect._has_native_hstore:
-            return None
-        else:
-            return super(_PGHStore, self).bind_processor(dialect)
-
-    def result_processor(self, dialect, coltype):
-        if dialect._has_native_hstore:
-            return None
-        else:
-            return super(_PGHStore, self).result_processor(dialect, coltype)
-
-
-class _PGARRAY(PGARRAY):
-    render_bind_cast = True
-
-
 class _PGJSON(JSON):
     def result_processor(self, dialect, coltype):
         return None
@@ -526,40 +471,9 @@ class _PGJSONB(JSONB):
         return None
 
 
-class _PGUUID(UUID):
-    def bind_processor(self, dialect):
-        if not self.as_uuid and dialect.use_native_uuid:
-
-            def process(value):
-                if value is not None:
-                    value = _python_UUID(value)
-                return value
-
-            return process
-
-    def result_processor(self, dialect, coltype):
-        if not self.as_uuid and dialect.use_native_uuid:
-
-            def process(value):
-                if value is not None:
-                    value = str(value)
-                return value
-
-            return process
-
-
-_server_side_id = util.counter()
-
-
-class PGExecutionContext_psycopg2(PGExecutionContext):
+class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
     _psycopg2_fetched_rows = None
 
-    def create_server_side_cursor(self):
-        # use server-side cursors:
-        # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-        ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
-        return self._dbapi_connection.cursor(ident)
-
     def post_exec(self):
         if (
             self._psycopg2_fetched_rows
@@ -614,7 +528,7 @@ EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
 )
 
 
-class PGDialect_psycopg2(PGDialect):
+class PGDialect_psycopg2(_PGDialect_common_psycopg):
     driver = "psycopg2"
 
     supports_statement_cache = True
@@ -631,34 +545,22 @@ class PGDialect_psycopg2(PGDialect):
     _has_native_hstore = True
 
     colspecs = util.update_copy(
-        PGDialect.colspecs,
+        _PGDialect_common_psycopg.colspecs,
         {
-            sqltypes.Numeric: _PGNumeric,
-            HSTORE: _PGHStore,
             JSON: _PGJSON,
             sqltypes.JSON: _PGJSON,
             JSONB: _PGJSONB,
-            UUID: _PGUUID,
-            sqltypes.ARRAY: _PGARRAY,
         },
     )
 
     def __init__(
         self,
-        client_encoding=None,
-        use_native_hstore=True,
-        use_native_uuid=True,
         executemany_mode="values_only",
         executemany_batch_page_size=100,
         executemany_values_page_size=1000,
         **kwargs
     ):
-        PGDialect.__init__(self, **kwargs)
-        if not use_native_hstore:
-            self._has_native_hstore = False
-        self.use_native_hstore = use_native_hstore
-        self.use_native_uuid = use_native_uuid
-        self.client_encoding = client_encoding
+        _PGDialect_common_psycopg.__init__(self, **kwargs)
 
         # Parse executemany_mode argument, allowing it to be only one of the
         # symbol names
@@ -737,9 +639,6 @@ class PGDialect_psycopg2(PGDialect):
             "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
         }
 
-    def get_isolation_level_values(self, dbapi_conn):
-        return list(self._isolation_lookup)
-
     def set_isolation_level(self, connection, level):
         connection.set_isolation_level(self._isolation_lookup[level])
 
@@ -755,25 +654,6 @@ class PGDialect_psycopg2(PGDialect):
     def get_deferrable(self, connection):
         return connection.deferrable
 
-    def do_ping(self, dbapi_connection):
-        cursor = None
-        try:
-            dbapi_connection.autocommit = True
-            cursor = dbapi_connection.cursor()
-            try:
-                cursor.execute(self._dialect_specific_select_one)
-            finally:
-                cursor.close()
-                if not dbapi_connection.closed:
-                    dbapi_connection.autocommit = False
-        except self.dbapi.Error as err:
-            if self.is_disconnect(err, dbapi_connection, cursor):
-                return False
-            else:
-                raise
-        else:
-            return True
-
     def on_connect(self):
         extras = self._psycopg2_extras
 
@@ -911,33 +791,6 @@ class PGDialect_psycopg2(PGDialect):
         else:
             return None
 
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username="user")
-
-        is_multihost = False
-        if "host" in url.query:
-            is_multihost = isinstance(url.query["host"], (list, tuple))
-
-        if opts:
-            if "port" in opts:
-                opts["port"] = int(opts["port"])
-            opts.update(url.query)
-            if is_multihost:
-                opts["host"] = ",".join(url.query["host"])
-            # send individual dbname, user, password, host, port
-            # parameters to psycopg2.connect()
-            return ([], opts)
-        elif url.query:
-            # any other connection arguments, pass directly
-            opts.update(url.query)
-            if is_multihost:
-                opts["host"] = ",".join(url.query["host"])
-            return ([], opts)
-        else:
-            # no connection arguments whatsoever; psycopg2.connect()
-            # requires that "dsn" be present as a blank string.
-            return ([""], opts)
-
     def is_disconnect(self, e, connection, cursor):
         if isinstance(e, self.dbapi.Error):
             # check the "closed" flag.  this might not be
index f9a65a0f8d164413eaac9dc26e9ce770e06456f8..8fcba7503892507897aa2405f1314c5193996e1d 100644 (file)
@@ -458,7 +458,11 @@ def create_engine(url, **kwargs):
     u, plugins, kwargs = u._instantiate_plugins(kwargs)
 
     entrypoint = u._get_entrypoint()
-    dialect_cls = entrypoint.get_dialect_cls(u)
+    _is_async = kwargs.pop("_is_async", False)
+    if _is_async:
+        dialect_cls = entrypoint.get_async_dialect_cls(u)
+    else:
+        dialect_cls = entrypoint.get_dialect_cls(u)
 
     if kwargs.pop("_coerce_config", False):
 
index 1f1a2fcf1e69065d5ef5a442d798137c64ee78b4..7f2b8b412c30111b5d1c8874f966cd3509d405a8 100644 (file)
@@ -1223,6 +1223,7 @@ class BaseCursorResult:
 
 
         """
+
         if (not hard and self._soft_closed) or (hard and self.closed):
             return
 
index cb04eb525b9473a81d7a71d0437fd34fa37bdcf5..64500b41baaa1865dd50a964913a94c5967bcb03 100644 (file)
@@ -1284,6 +1284,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         result.out_parameters = out_parameters
 
     def _setup_dml_or_text_result(self):
+
         if self.isinsert:
             if self.compiled.postfetch_lastrowid:
                 self.inserted_primary_key_rows = (
@@ -1332,8 +1333,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 # assert not result.returns_rows
 
         elif self.isupdate and self._is_implicit_returning:
+            # get the rowcount
+            # (which requires an open cursor on some drivers).
+            # this was not done in 1.4, however
+            # test_rowcount -> test_update_rowcount_return_defaults
+            # tests for it, and psycopg no longer returns the
+            # rowcount after the cursor is closed.
+            result.rowcount
+
             row = result.fetchone()
             self.returned_default_rows = [row]
+
             result._soft_close()
 
             # test that it has a cursor metadata that is accurate.
@@ -1410,6 +1420,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
         dialect = self.dialect
 
+        # all of the rest of this... cython?
+
         if dialect._has_events:
             inputsizes = dict(inputsizes)
             dialect.dispatch.do_setinputsizes(
index 251d01c5e21b41b20eebb7f7e9fe251816514f05..faaf073ab07484dfaa25ada11433fb3b86494843 100644 (file)
@@ -1113,6 +1113,25 @@ class Dialect:
         """
         return cls
 
+    @classmethod
+    def get_async_dialect_cls(cls, url):
+        """Given a URL, return the :class:`.Dialect` that will be used by
+        an async engine.
+
+        By default this is an alias of :meth:`.Dialect.get_dialect_cls` and
+        just returns the cls. It may be used if a dialect provides
+        both a sync and async version under the same name, like the
+        ``psycopg`` driver.
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.Dialect.get_dialect_cls`
+
+        """
+        return cls.get_dialect_cls(url)
+
     @classmethod
     def load_provisioning(cls):
         """set up the provision.py module for this dialect.
index c83753bdc95c13ef4bdf2e08e2b1427a39e06f64..7cdf25c21aadc6e4937c53dc98230b60f9accb9f 100644 (file)
@@ -655,13 +655,16 @@ class URL(
         else:
             return cls
 
-    def get_dialect(self):
+    def get_dialect(self, _is_async=False):
         """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
         to this URL's driver name.
 
         """
         entrypoint = self._get_entrypoint()
-        dialect_cls = entrypoint.get_dialect_cls(self)
+        if _is_async:
+            dialect_cls = entrypoint.get_async_dialect_cls(self)
+        else:
+            dialect_cls = entrypoint.get_dialect_cls(self)
         return dialect_cls
 
     def translate_connect_args(self, names=None, **kw):
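Both ``create_engine()`` and ``URL.get_dialect()`` now consult
``get_async_dialect_cls()`` when the async path is requested. A hypothetical
sketch of how a third-party dialect could use this hook to offer sync and
async variants under one name (``MyDialect`` and ``MyAsyncDialect`` are
illustrative names, not part of this commit):

    from sqlalchemy.engine.default import DefaultDialect

    class MyDialect(DefaultDialect):
        name = "mydb"
        driver = "mydriver"

        @classmethod
        def get_async_dialect_cls(cls, url):
            # returned when create_async_engine() is used with this URL
            return MyAsyncDialect

    class MyAsyncDialect(MyDialect):
        is_async = True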
index 221d82f088acae17e1a197ba499bb2a6d52c8278..67d944b9ce608f0ce49de2389772b2244523975e 100644 (file)
@@ -37,6 +37,7 @@ def create_async_engine(*arg, **kw):
             "streaming result set"
         )
     kw["future"] = True
+    kw["_is_async"] = True
     sync_engine = _create_engine(*arg, **kw)
     return AsyncEngine(sync_engine)
 
index 8874f0a83defd851ce7b8e2c1594844263c517a8..427251a88513fd9e016bd6d3d4e5c5e09bc5681b 100644 (file)
@@ -2338,26 +2338,40 @@ class JSON(Indexable, TypeEngine):
     def _str_impl(self):
         return String()
 
-    def bind_processor(self, dialect):
-        string_process = self._str_impl.bind_processor(dialect)
+    def _make_bind_processor(self, string_process, json_serializer):
+        if string_process:
 
-        json_serializer = dialect._json_serializer or json.dumps
+            def process(value):
+                if value is self.NULL:
+                    value = None
+                elif isinstance(value, elements.Null) or (
+                    value is None and self.none_as_null
+                ):
+                    return None
 
-        def process(value):
-            if value is self.NULL:
-                value = None
-            elif isinstance(value, elements.Null) or (
-                value is None and self.none_as_null
-            ):
-                return None
+                serialized = json_serializer(value)
+                return string_process(serialized)
 
-            serialized = json_serializer(value)
-            if string_process:
-                serialized = string_process(serialized)
-            return serialized
+        else:
+
+            def process(value):
+                if value is self.NULL:
+                    value = None
+                elif isinstance(value, elements.Null) or (
+                    value is None and self.none_as_null
+                ):
+                    return None
+
+                return json_serializer(value)
 
         return process
 
+    def bind_processor(self, dialect):
+        string_process = self._str_impl.bind_processor(dialect)
+        json_serializer = dialect._json_serializer or json.dumps
+
+        return self._make_bind_processor(string_process, json_serializer)
+
     def result_processor(self, dialect, coltype):
         string_process = self._str_impl.result_processor(dialect, coltype)
         json_deserializer = dialect._json_deserializer or json.loads
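The bind-processor split above lets a dialect substitute its own wrapper
for the string serializer, which is how ``_PGJSON.bind_processor()``
earlier in this commit passes psycopg's ``Json`` type. From the user's
perspective, custom serializers continue to flow in through
``create_engine()``; a brief sketch:

    import json

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg://scott:tiger@localhost/test",
        json_serializer=lambda obj: json.dumps(obj, sort_keys=True),
        json_deserializer=json.loads,
    )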
index 8faeea6341b1efe958087ee99cc5e2ff60cae681..22d9c523a268f2c3aaf57a6c583eda59c4898fd0 100644 (file)
@@ -124,11 +124,12 @@ class Config:
     _configs = set()
 
     def _set_name(self, db):
+        suffix = "_async" if db.dialect.is_async else ""
         if db.dialect.server_version_info:
             svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
-            self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
+            self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi)
         else:
-            self.name = "%s+%s" % (db.name, db.driver)
+            self.name = "%s+%s%s" % (db.name, db.driver, suffix)
 
     @classmethod
     def register(cls, db, db_opts, options, file_config):
index 32ed2c31593f84c49b439e83df1f5d24135ac0ae..d79931b91ea7e6976290304ab8839ec94aeabe3b 100644 (file)
@@ -706,7 +706,8 @@ def _do_skips(cls):
                 )
 
     if not all_configs:
-        msg = "'%s' unsupported on any DB implementation %s%s" % (
+        msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
+            cls.__module__,
             cls.__name__,
             ", ".join(
                 "'%s(%s)+%s'"
index 8cb72d16336052d4a1a0907f02f344876d1874d6..f811be657325d6732d36d4f2dc3042837ee3fa8c 100644 (file)
@@ -1493,3 +1493,13 @@ class SuiteRequirements(Requirements):
         sequence. This should be false only for oracle.
         """
         return exclusions.open()
+
+    @property
+    def generic_classes(self):
+        "Whether X[Y] can be implemented with ``__class_getitem__``; py3.7+"
+        return exclusions.open()
+
+    @property
+    def json_deserializer_binary(self):
+        "indicates if the json_deserializer function is called with bytes"
+        return exclusions.closed()
index 5ad68034b9a89d33777d2bc260b67bb8fcc4ea8b..a8900ece148961ee2d9710434e630f31f9da1fa8 100644 (file)
@@ -244,6 +244,8 @@ class ServerSideCursorsTest(
             return cursor.server_side
         elif self.engine.dialect.driver == "pg8000":
             return getattr(cursor, "server_side", False)
+        elif self.engine.dialect.driver == "psycopg":
+            return bool(getattr(cursor, "name", False))
         else:
             return False
 
index 82e6fa238394f7316cff774da083e96583a332ff..e7131ec6ea66402d920f55aa4abfec0cdcce67ac 100644 (file)
@@ -1107,7 +1107,13 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
             eq_(row, (data_element,))
             eq_(js.mock_calls, [mock.call(data_element)])
-            eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+            if testing.requires.json_deserializer_binary.enabled:
+                eq_(
+                    jd.mock_calls,
+                    [mock.call(json.dumps(data_element).encode())],
+                )
+            else:
+                eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
 
     @testing.combinations(
         ("parameters",),
index 624ace1bdf1690c1992343f7437d3ad01fd17866..80477f8a4e753ea091d16ec43f2e9d23e83c2662 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -66,6 +66,7 @@ postgresql_asyncpg =
     asyncpg;python_version>="3"
 postgresql_psycopg2binary = psycopg2-binary
 postgresql_psycopg2cffi = psycopg2cffi
+postgresql_psycopg = psycopg>=3.0.2
 pymysql =
     pymysql;python_version>="3"
     pymysql<1;python_version<"3"
@@ -157,6 +158,10 @@ sqlite_file = sqlite:///querytest.db
 aiosqlite_file = sqlite+aiosqlite:///async_querytest.db
 pysqlcipher_file = sqlite+pysqlcipher://:test@/querytest.db.enc
 postgresql = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
+psycopg2 = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test
+psycopg = postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
+psycopg_async = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test
+psycopg_async_fallback = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
 asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
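These ``[db]`` entries are the names the test suite accepts through its
``--db`` option, so the async variant can be exercised with, e.g.,
``pytest --db psycopg_async`` (assuming the standard SQLAlchemy test runner
configuration).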
index 57682686c8637e5ceab0d3d898e01ab998fdef5f..02d7ad483ec5b3ff82a536212600432b66786e41 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import BigInteger
 from sqlalchemy import bindparam
 from sqlalchemy import cast
 from sqlalchemy import Column
+from sqlalchemy import create_engine
 from sqlalchemy import DateTime
 from sqlalchemy import DDL
 from sqlalchemy import event
@@ -30,6 +31,9 @@ from sqlalchemy import text
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
 from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import psycopg as psycopg_dialect
 from sqlalchemy.dialects.postgresql import psycopg2 as psycopg2_dialect
 from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_BATCH
 from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_PLAIN
@@ -269,10 +273,12 @@ class PGCodeTest(fixtures.TestBase):
         if testing.against("postgresql+pg8000"):
             # TODO: is there another way we're supposed to see this?
             eq_(errmsg.orig.args[0]["C"], "23505")
-        else:
+        elif not testing.against("postgresql+psycopg"):
             eq_(errmsg.orig.pgcode, "23505")
 
-        if testing.against("postgresql+asyncpg"):
+        if testing.against("postgresql+asyncpg") or testing.against(
+            "postgresql+psycopg"
+        ):
             eq_(errmsg.orig.sqlstate, "23505")
 
 
@@ -858,6 +864,13 @@ class MiscBackendTest(
             ".".join(str(x) for x in v)
         )
 
+    @testing.only_on("postgresql+psycopg")
+    def test_psycopg_version(self):
+        v = testing.db.dialect.psycopg_version
+        assert testing.db.dialect.dbapi.__version__.startswith(
+            ".".join(str(x) for x in v)
+        )
+
     @testing.combinations(
         ((8, 1), False, False),
         ((8, 1), None, False),
@@ -902,6 +915,7 @@ class MiscBackendTest(
         with testing.db.connect().execution_options(
             isolation_level="SERIALIZABLE"
         ) as conn:
+
             dbapi_conn = conn.connection.dbapi_connection
 
             is_false(dbapi_conn.autocommit)
@@ -1069,25 +1083,30 @@ class MiscBackendTest(
                 dbapi_conn.rollback()
             eq_(val, "off")
 
-    @testing.requires.psycopg2_compatibility
-    def test_psycopg2_non_standard_err(self):
+    @testing.requires.psycopg_compatibility
+    def test_psycopg_non_standard_err(self):
         # note that psycopg2 is sometimes called psycopg2cffi
         # depending on platform
-        psycopg2 = testing.db.dialect.dbapi
-        TransactionRollbackError = __import__(
-            "%s.extensions" % psycopg2.__name__
-        ).extensions.TransactionRollbackError
+        psycopg = testing.db.dialect.dbapi
+        if psycopg.__version__.startswith("3"):
+            TransactionRollbackError = __import__(
+                "%s.errors" % psycopg.__name__
+            ).errors.TransactionRollback
+        else:
+            TransactionRollbackError = __import__(
+                "%s.extensions" % psycopg.__name__
+            ).extensions.TransactionRollbackError
 
         exception = exc.DBAPIError.instance(
             "some statement",
             {},
             TransactionRollbackError("foo"),
-            psycopg2.Error,
+            psycopg.Error,
         )
         assert isinstance(exception, exc.OperationalError)
 
     @testing.requires.no_coverage
-    @testing.requires.psycopg2_compatibility
+    @testing.requires.psycopg_compatibility
     def test_notice_logging(self):
         log = logging.getLogger("sqlalchemy.dialects.postgresql")
         buf = logging.handlers.BufferingHandler(100)
@@ -1115,14 +1134,14 @@ $$ LANGUAGE plpgsql;
         finally:
             log.removeHandler(buf)
             log.setLevel(lev)
-        msgs = " ".join(b.msg for b in buf.buffer)
+        msgs = " ".join(b.getMessage() for b in buf.buffer)
         eq_regex(
             msgs,
-            "NOTICE:  notice: hi there(\nCONTEXT: .*?)? "
-            "NOTICE:  notice: another note(\nCONTEXT: .*?)?",
+            "NOTICE: [ ]?notice: hi there(\nCONTEXT: .*?)? "
+            "NOTICE: [ ]?notice: another note(\nCONTEXT: .*?)?",
         )
 
-    @testing.requires.psycopg2_or_pg8000_compatibility
+    @testing.requires.psycopg_or_pg8000_compatibility
     @engines.close_open_connections
     def test_client_encoding(self):
         c = testing.db.connect()
@@ -1143,7 +1162,7 @@ $$ LANGUAGE plpgsql;
         new_encoding = c.exec_driver_sql("show client_encoding").fetchone()[0]
         eq_(new_encoding, test_encoding)
 
-    @testing.requires.psycopg2_or_pg8000_compatibility
+    @testing.requires.psycopg_or_pg8000_compatibility
     @engines.close_open_connections
     def test_autocommit_isolation_level(self):
         c = testing.db.connect().execution_options(
@@ -1302,7 +1321,7 @@ $$ LANGUAGE plpgsql;
         assert result == [(1, "user", "lala")]
         connection.execute(text("DROP TABLE speedy_users"))
 
-    @testing.requires.psycopg2_or_pg8000_compatibility
+    @testing.requires.psycopg_or_pg8000_compatibility
     def test_numeric_raise(self, connection):
         stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
         assert_raises(exc.InvalidRequestError, connection.execute, stmt)
@@ -1364,9 +1383,90 @@ $$ LANGUAGE plpgsql;
             )
 
     @testing.requires.psycopg2_compatibility
-    def test_initial_transaction_state(self):
+    def test_initial_transaction_state_psycopg2(self):
         from psycopg2.extensions import STATUS_IN_TRANSACTION
 
         engine = engines.testing_engine()
         with engine.connect() as conn:
             ne_(conn.connection.status, STATUS_IN_TRANSACTION)
+
+    @testing.only_on("postgresql+psycopg")
+    def test_initial_transaction_state_psycopg(self):
+        from psycopg.pq import TransactionStatus
+
+        engine = engines.testing_engine()
+        with engine.connect() as conn:
+            ne_(
+                conn.connection.dbapi_connection.info.transaction_status,
+                TransactionStatus.INTRANS,
+            )
+
+
+class Psycopg3Test(fixtures.TestBase):
+    __only_on__ = ("postgresql+psycopg",)
+
+    def test_json_correctly_registered(self, testing_engine):
+        import json
+
+        def loads(value):
+            value = json.loads(value)
+            value["x"] = value["x"] + "_loads"
+            return value
+
+        def dumps(value):
+            value = dict(value)
+            value["x"] = "dumps_y"
+            return json.dumps(value)
+
+        engine = testing_engine(
+            options=dict(json_serializer=dumps, json_deserializer=loads)
+        )
+        engine2 = testing_engine(
+            options=dict(
+                json_serializer=json.dumps, json_deserializer=json.loads
+            )
+        )
+
+        s = select(cast({"key": "value", "x": "q"}, JSONB))
+        with engine.begin() as conn:
+            eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+            with engine.begin() as conn:
+                eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+                with engine2.begin() as conn:
+                    eq_(conn.scalar(s), {"key": "value", "x": "q"})
+                with engine.begin() as conn:
+                    eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+
+    @testing.requires.hstore
+    def test_hstore_correctly_registered(self, testing_engine):
+        engine = testing_engine(options=dict(use_native_hstore=True))
+        engine2 = testing_engine(options=dict(use_native_hstore=False))
+
+        def rp(self, *a):
+            return lambda a: {"a": "b"}
+
+        with mock.patch.object(HSTORE, "result_processor", side_effect=rp):
+            s = select(cast({"key": "value", "x": "q"}, HSTORE))
+            with engine.begin() as conn:
+                eq_(conn.scalar(s), {"key": "value", "x": "q"})
+                with engine.begin() as conn:
+                    eq_(conn.scalar(s), {"key": "value", "x": "q"})
+                    with engine2.begin() as conn:
+                        eq_(conn.scalar(s), {"a": "b"})
+                    with engine.begin() as conn:
+                        eq_(conn.scalar(s), {"key": "value", "x": "q"})
+
+    def test_get_dialect(self):
+        u = url.URL.create("postgresql://")
+        d = psycopg_dialect.PGDialect_psycopg.get_dialect_cls(u)
+        is_(d, psycopg_dialect.PGDialect_psycopg)
+        d = psycopg_dialect.PGDialect_psycopg.get_async_dialect_cls(u)
+        is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+        d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u)
+        is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+        d = psycopg_dialect.PGDialectAsync_psycopg.get_async_dialect_cls(u)
+        is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+
+    def test_async_version(self):
+        e = create_engine("postgresql+psycopg_async://")
+        is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg))
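
The ``Psycopg3Test`` assertions above pin down the new resolution rules:
``get_dialect_cls()`` returns the sync dialect class,
``get_async_dialect_cls()`` returns the async one, and the explicit
``postgresql+psycopg_async`` name yields the async dialect class even through
plain ``create_engine()``. In user terms this looks roughly as follows (host
and credentials are placeholders, not part of this change):

    from sqlalchemy import create_engine
    from sqlalchemy.ext.asyncio import create_async_engine

    # one driver name; the sync or async dialect is picked to match
    # the engine factory being used
    sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
    async_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")

    # the explicit async name pins the async dialect class,
    # as test_async_version asserts
    e = create_engine("postgresql+psycopg_async://")
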
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index b488b146cf70248a89f83a6753e870c9f2967cba..fdce643f840f0ce9c48a28ef20679da1e2e4771e 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
@@ -29,6 +30,7 @@ from sqlalchemy import true
 from sqlalchemy import tuple_
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.sql.expression import type_coerce
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
@@ -40,6 +42,17 @@ from sqlalchemy.testing.assertsql import CursorSQL
 from sqlalchemy.testing.assertsql import DialectSQL
 
 
+class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults):
+    __only_on__ = "postgresql"
+    __backend__ = True
+
+    def test_count_star(self, connection):
+        eq_(connection.scalar(func.count("*")), 1)
+
+    def test_count_int(self, connection):
+        eq_(connection.scalar(func.count(1)), 1)
+
+
 class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
     __only_on__ = "postgresql"
@@ -956,23 +969,42 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
             ],
         )
 
+    def _strs_render_bind_casts(self, connection):
+
+        return (
+            connection.dialect._bind_typing_render_casts
+            and String().dialect_impl(connection.dialect).render_bind_cast
+        )
+
     @testing.requires.pyformat_paramstyle
-    def test_expression_pyformat(self):
+    def test_expression_pyformat(self, connection):
         matchtable = self.tables.matchtable
-        self.assert_compile(
-            matchtable.c.title.match("somstr"),
-            "matchtable.title @@ to_tsquery(%(title_1)s" ")",
-        )
+
+        if self._strs_render_bind_casts(connection):
+            self.assert_compile(
+                matchtable.c.title.match("somstr"),
+                "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))",
+            )
+        else:
+            self.assert_compile(
+                matchtable.c.title.match("somstr"),
+                "matchtable.title @@ to_tsquery(%(title_1)s)",
+            )
 
     @testing.requires.format_paramstyle
-    def test_expression_positional(self):
+    def test_expression_positional(self, connection):
         matchtable = self.tables.matchtable
-        self.assert_compile(
-            matchtable.c.title.match("somstr"),
-            # note we assume current tested DBAPIs use emulated setinputsizes
-            # here, the cast is not strictly necessary
-            "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
-        )
+
+        if self._strs_render_bind_casts(connection):
+            self.assert_compile(
+                matchtable.c.title.match("somstr"),
+                "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
+            )
+        else:
+            self.assert_compile(
+                matchtable.c.title.match("somstr"),
+                "matchtable.title @@ to_tsquery(%s)",
+            )
 
     def test_simple_match(self, connection):
         matchtable = self.tables.matchtable
@@ -1551,17 +1583,106 @@ class TableValuedRoundTripTest(fixtures.TestBase):
             [(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)],
         )
 
-    @testing.only_on(
-        "postgresql+psycopg2",
-        "I cannot get this to run at all on other drivers, "
-        "even selecting from a table",
+    @testing.combinations(
+        (
+            type_coerce,
+            testing.fails("fails on all drivers"),
+        ),
+        (
+            cast,
+            testing.fails("fails on all drivers"),
+        ),
+        (
+            None,
+            testing.fails_on_everything_except(
+                ["postgresql+psycopg2"],
+                "I cannot get this to run at all on other drivers, "
+                "even selecting from a table",
+            ),
+        ),
+        argnames="cast_fn",
     )
-    def test_render_derived_quoting(self, connection):
+    def test_render_derived_quoting_text(self, connection, cast_fn):
+
+        value = (
+            '[{"CaseSensitive":1,"the % value":"foo"}, '
+            '{"CaseSensitive":"2","the % value":"bar"}]'
+        )
+
+        if cast_fn:
+            value = cast_fn(value, JSON)
+
         fn = (
-            func.json_to_recordset(  # noqa
-                '[{"CaseSensitive":1,"the % value":"foo"}, '
-                '{"CaseSensitive":"2","the % value":"bar"}]'
+            func.json_to_recordset(value)
+            .table_valued(
+                column("CaseSensitive", Integer), column("the % value", String)
             )
+            .render_derived(with_types=True)
+        )
+
+        stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+        eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+    @testing.combinations(
+        (
+            type_coerce,
+            testing.fails("fails on all drivers"),
+        ),
+        (
+            cast,
+            testing.fails("fails on all drivers"),
+        ),
+        (
+            None,
+            testing.fails("Fails on all drivers"),
+        ),
+        argnames="cast_fn",
+    )
+    def test_render_derived_quoting_text_to_json(self, connection, cast_fn):
+
+        value = (
+            '[{"CaseSensitive":1,"the % value":"foo"}, '
+            '{"CaseSensitive":"2","the % value":"bar"}]'
+        )
+
+        if cast_fn:
+            value = cast_fn(value, JSON)
+
+        # why won't this work?!?!?
+        # should be exactly json_to_recordset(to_json('string'::text))
+        #
+        fn = (
+            func.json_to_recordset(func.to_json(value))
+            .table_valued(
+                column("CaseSensitive", Integer), column("the % value", String)
+            )
+            .render_derived(with_types=True)
+        )
+
+        stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+        eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+    @testing.combinations(
+        (type_coerce,),
+        (cast,),
+        (None, testing.fails("Fails on all PG backends")),
+        argnames="cast_fn",
+    )
+    def test_render_derived_quoting_straight_json(self, connection, cast_fn):
+        # these all work
+
+        value = [
+            {"CaseSensitive": 1, "the % value": "foo"},
+            {"CaseSensitive": "2", "the % value": "bar"},
+        ]
+
+        if cast_fn:
+            value = cast_fn(value, JSON)
+
+        fn = (
+            func.json_to_recordset(value)  # noqa
             .table_valued(
                 column("CaseSensitive", Integer), column("the % value", String)
             )
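
The three ``test_render_derived_quoting_*`` variants above probe how the JSON
document may be handed to ``json_to_recordset()``: a plain string bind fails
across drivers, while passing the Python structure through ``cast()`` or
``type_coerce()`` with ``JSON`` works everywhere. A condensed sketch of the
passing path, matching ``test_render_derived_quoting_straight_json``:

    from sqlalchemy import JSON, Integer, String, cast, column, func, select

    value = cast(
        [
            {"CaseSensitive": 1, "the % value": "foo"},
            {"CaseSensitive": "2", "the % value": "bar"},
        ],
        JSON,
    )
    fn = (
        func.json_to_recordset(value)
        .table_valued(column("CaseSensitive", Integer), column("the % value", String))
        .render_derived(with_types=True)
    )
    stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
    # connection.execute(stmt).all() is expected to return [(1, "foo"), (2, "bar")]
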
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 4008881d216c19186e7cebce9ec10a7a5c8a40a8..5f8a41d1f831ff8178dabd3de966f102f647803e 100644 (file)
@@ -2360,7 +2360,7 @@ class ArrayEnum(fixtures.TestBase):
             testing.combinations(
                 sqltypes.ARRAY,
                 postgresql.ARRAY,
-                (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+                (_ArrayOfEnum, testing.requires.psycopg_compatibility),
                 argnames="array_cls",
             )(fn)
         )
@@ -3066,7 +3066,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
 
     @testing.fixture
     def non_native_hstore_connection(self, testing_engine):
-        local_engine = testing.requires.psycopg2_native_hstore.enabled
+        local_engine = testing.requires.native_hstore.enabled
 
         if local_engine:
             engine = testing_engine(options=dict(use_native_hstore=False))
@@ -3096,14 +3096,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
         )["1"]
         eq_(connection.scalar(select(expr)), "3")
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_insert_native(self, connection):
         self._test_insert(connection)
 
     def test_insert_python(self, non_native_hstore_connection):
         self._test_insert(non_native_hstore_connection)
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_criterion_native(self, connection):
         self._fixture_data(connection)
         self._test_criterion(connection)
@@ -3134,7 +3134,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
     def test_fixed_round_trip_python(self, non_native_hstore_connection):
         self._test_fixed_round_trip(non_native_hstore_connection)
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_fixed_round_trip_native(self, connection):
         self._test_fixed_round_trip(connection)
 
@@ -3154,11 +3154,11 @@ class HStoreRoundTripTest(fixtures.TablesTest):
             },
         )
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_unicode_round_trip_python(self, non_native_hstore_connection):
         self._test_unicode_round_trip(non_native_hstore_connection)
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_unicode_round_trip_native(self, connection):
         self._test_unicode_round_trip(connection)
 
@@ -3167,7 +3167,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
     ):
         self._test_escaped_quotes_round_trip(non_native_hstore_connection)
 
-    @testing.requires.psycopg2_native_hstore
+    @testing.requires.native_hstore
     def test_escaped_quotes_round_trip_native(self, connection):
         self._test_escaped_quotes_round_trip(connection)
 
@@ -3356,7 +3356,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
 
 
 class _RangeTypeRoundTrip(fixtures.TablesTest):
-    __requires__ = "range_types", "psycopg2_compatibility"
+    __requires__ = "range_types", "psycopg_compatibility"
     __backend__ = True
 
     def extras(self):
@@ -3364,8 +3364,18 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         # older psycopg2 versions.
         if testing.against("postgresql+psycopg2cffi"):
             from psycopg2cffi import extras
-        else:
+        elif testing.against("postgresql+psycopg2"):
             from psycopg2 import extras
+        elif testing.against("postgresql+psycopg"):
+            from psycopg.types.range import Range
+
+            class psycopg_extras:
+                def __getattr__(self, _):
+                    return Range
+
+            extras = psycopg_extras()
+        else:
+            assert False, "Unknonw dialect"
         return extras
 
     @classmethod
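
The ``extras()`` shim above papers over an API difference between the two
driver generations: psycopg2 exposes one class per range type in
``psycopg2.extras`` (``NumericRange``, ``DateTimeRange`` and so on), while
psycopg 3 ships a single generic ``psycopg.types.range.Range``, so the shim
answers every attribute lookup with ``Range``. Roughly, assuming both drivers
are installed:

    # psycopg2: a dedicated class per range type
    from psycopg2.extras import NumericRange
    r2 = NumericRange(1, 10)

    # psycopg 3: one generic class covers all range types
    from psycopg.types.range import Range
    r3 = Range(1, 10)
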
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index c69b332a8f7c5bc3f7ade3fc328b7266048bfc16..f12d32d5d0e558cf12d9375288fafd0c43ec9673 100644 (file)
@@ -1037,6 +1037,64 @@ class TestRegNewDBAPI(fixtures.TestBase):
         )
 
 
+class TestGetDialect(fixtures.TestBase):
+    @testing.requires.sqlite
+    @testing.combinations(True, False, None)
+    def test_is_async_to_create_engine(self, is_async):
+        def get_dialect_cls(url):
+            url = url.set(drivername="sqlite")
+            return url.get_dialect()
+
+        global MockDialectGetDialect
+        MockDialectGetDialect = Mock()
+        MockDialectGetDialect.get_dialect_cls.side_effect = get_dialect_cls
+        MockDialectGetDialect.get_async_dialect_cls.side_effect = (
+            get_dialect_cls
+        )
+
+        registry.register("mockdialect", __name__, "MockDialectGetDialect")
+
+        from sqlalchemy.dialects import sqlite
+
+        kw = {}
+        if is_async is not None:
+            kw["_is_async"] = is_async
+        e = create_engine("mockdialect://", **kw)
+
+        eq_(e.dialect.name, "sqlite")
+        assert isinstance(e.dialect, sqlite.dialect)
+
+        if is_async:
+            eq_(
+                MockDialectGetDialect.mock_calls,
+                [
+                    call.get_async_dialect_cls(url.make_url("mockdialect://")),
+                    call.engine_created(e),
+                ],
+            )
+        else:
+            eq_(
+                MockDialectGetDialect.mock_calls,
+                [
+                    call.get_dialect_cls(url.make_url("mockdialect://")),
+                    call.engine_created(e),
+                ],
+            )
+        MockDialectGetDialect.reset_mock()
+        u = url.make_url("mockdialect://")
+        u.get_dialect(**kw)
+        if is_async:
+            eq_(
+                MockDialectGetDialect.mock_calls,
+                [call.get_async_dialect_cls(u)],
+            )
+        else:
+            eq_(
+                MockDialectGetDialect.mock_calls,
+                [call.get_dialect_cls(u)],
+            )
+
+
 class MockDialect(DefaultDialect):
     @classmethod
     def dbapi(cls, **kw):
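
``TestGetDialect`` above fixes the contract for engine creation:
``create_engine()`` resolves the dialect through ``get_dialect_cls()``, while
the async path (signalled by the private ``_is_async`` flag) goes through
``get_async_dialect_cls()``. A dialect shipping both modes can route between
two classes; a minimal sketch, not the actual psycopg implementation:

    from sqlalchemy.engine.default import DefaultDialect

    class MySyncDialect(DefaultDialect):
        @classmethod
        def get_dialect_cls(cls, url):
            return MySyncDialect

        @classmethod
        def get_async_dialect_cls(cls, url):
            # hypothetical async variant of this dialect
            return MyAsyncDialect

    class MyAsyncDialect(MySyncDialect):
        is_async = True
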
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index afd027698e68d29406af7c97f12b2866b4ac3de4..f37d1893f0f0f3e5e8dc5634ebf2f316b79e3bf7 100644 (file)
@@ -365,7 +365,7 @@ class MockReconnectTest(fixtures.TestBase):
         )
 
         self.mock_connect = call(
-            host="localhost", password="bar", user="foo", database="test"
+            host="localhost", password="bar", user="foo", dbname="test"
         )
         # monkeypatch disconnect checker
         self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(
@@ -1321,6 +1321,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
             "+aiosqlite",
             "+aiomysql",
             "+asyncmy",
+            "+psycopg",
         ],
         "Buffers the result set and doesn't check for connection close",
     )
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 44cf9388ca08b0c0ba8b164405763d6c9924bc46..aee71f8d5503cbd1f5e74eee2a63f791c1029330 100644 (file)
@@ -634,7 +634,11 @@ class AsyncEventTest(EngineFixture):
 
         eq_(
             canary.mock_calls,
-            [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+            [
+                mock.call(
+                    sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
+                )
+            ],
         )
 
     @async_test
@@ -651,7 +655,11 @@ class AsyncEventTest(EngineFixture):
 
         eq_(
             canary.mock_calls,
-            [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+            [
+                mock.call(
+                    sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
+                )
+            ],
         )
 
     @async_test
diff --git a/test/requirements.py b/test/requirements.py
index eeb71323bc59cea676810940216ac674cba6f2f7..3934dd23fc596e6edd55f005f942b38eac7a7b5e 100644 (file)
@@ -223,6 +223,7 @@ class DefaultRequirements(SuiteRequirements):
     def pyformat_paramstyle(self):
         return only_on(
             [
+                "postgresql+psycopg",
                 "postgresql+psycopg2",
                 "postgresql+psycopg2cffi",
                 "mysql+mysqlconnector",
@@ -1161,7 +1162,10 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def infinity_floats(self):
         return fails_on_everything_except(
-            "sqlite", "postgresql+psycopg2", "postgresql+asyncpg"
+            "sqlite",
+            "postgresql+psycopg2",
+            "postgresql+asyncpg",
+            "postgresql+psycopg",
         ) + skip_if(
             "postgresql+pg8000", "seems to work on pg14 only, not earlier?"
         )
@@ -1241,9 +1245,7 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def range_types(self):
         def check_range_types(config):
-            if not against(
-                config, ["postgresql+psycopg2", "postgresql+psycopg2cffi"]
-            ):
+            if not self.psycopg_compatibility.enabled:
                 return False
             try:
                 with config.db.connect() as conn:
@@ -1291,23 +1293,27 @@ class DefaultRequirements(SuiteRequirements):
         )
 
     @property
-    def psycopg2_native_hstore(self):
-        return self.psycopg2_compatibility
+    def native_hstore(self):
+        return self.psycopg_compatibility
 
     @property
     def psycopg2_compatibility(self):
         return only_on(["postgresql+psycopg2", "postgresql+psycopg2cffi"])
 
     @property
-    def psycopg2_or_pg8000_compatibility(self):
+    def psycopg_compatibility(self):
         return only_on(
             [
                 "postgresql+psycopg2",
                 "postgresql+psycopg2cffi",
-                "postgresql+pg8000",
+                "postgresql+psycopg",
             ]
         )
 
+    @property
+    def psycopg_or_pg8000_compatibility(self):
+        return only_on([self.psycopg_compatibility, "postgresql+pg8000"])
+
     @property
     def percent_schema_names(self):
         return skip_if(
@@ -1690,3 +1696,8 @@ class DefaultRequirements(SuiteRequirements):
     def reflect_tables_no_columns(self):
         # so far sqlite, mariadb, mysql don't support this
         return only_on(["postgresql"])
+
+    @property
+    def json_deserializer_binary(self):
+        "indicates if the json_deserializer function is called with bytes"
+        return only_on(["postgresql+psycopg"])
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index f63e1e01bc499b38c993fe90a8ce5f78a4f9d116..35467de9417cc6d9649b75a25a52979e3a3107ba 100644 (file)
@@ -2459,15 +2459,16 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
     @testing.combinations(
         (True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit"
     )
-    @testing.provide_metadata
     @testing.skip_if("mysql < 8")
-    def test_duplicate_values_accepted(self, native, omit):
+    def test_duplicate_values_accepted(
+        self, metadata, connection, native, omit
+    ):
         foo_enum = pep435_enum("foo_enum")
         foo_enum("one", 1, "two")
         foo_enum("three", 3, "four")
         tbl = sa.Table(
             "foo_table",
-            self.metadata,
+            metadata,
             sa.Column("id", sa.Integer),
             sa.Column(
                 "data",
@@ -2481,7 +2482,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         )
         t = sa.table("foo_table", sa.column("id"), sa.column("data"))
 
-        self.metadata.create_all(testing.db)
+        metadata.create_all(connection)
         if omit:
             with expect_raises(
                 (
@@ -2491,29 +2492,27 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
                     exc.DBAPIError,
                 )
             ):
-                with testing.db.begin() as conn:
-                    conn.execute(
-                        t.insert(),
-                        [
-                            {"id": 1, "data": "four"},
-                            {"id": 2, "data": "three"},
-                        ],
-                    )
-        else:
-            with testing.db.begin() as conn:
-                conn.execute(
+                connection.execute(
                     t.insert(),
-                    [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
+                    [
+                        {"id": 1, "data": "four"},
+                        {"id": 2, "data": "three"},
+                    ],
                 )
+        else:
+            connection.execute(
+                t.insert(),
+                [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
+            )
 
-                eq_(
-                    conn.execute(t.select().order_by(t.c.id)).fetchall(),
-                    [(1, "four"), (2, "three")],
-                )
-                eq_(
-                    conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
-                    [(1, foo_enum.three), (2, foo_enum.three)],
-                )
+            eq_(
+                connection.execute(t.select().order_by(t.c.id)).fetchall(),
+                [(1, "four"), (2, "three")],
+            )
+            eq_(
+                connection.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
+                [(1, foo_enum.three), (2, foo_enum.three)],
+            )
 
 
 MyPickleType = None
diff --git a/tox.ini b/tox.ini
index d8ba67a440c0c9f2019adc7c65ae567dcad56df0..9f6ccb0c65f735af17f291e187d284747e3c90da 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -27,6 +27,7 @@ deps=
      postgresql: .[postgresql]
      postgresql: .[postgresql_asyncpg]; python_version >= '3'
      postgresql: .[postgresql_pg8000]; python_version >= '3'
+     postgresql: .[postgresql_psycopg]; python_version >= '3'
 
      mysql: .[mysql]
      mysql: .[pymysql]
@@ -43,6 +44,8 @@ deps=
      dbapimain-postgresql: git+https://github.com/psycopg/psycopg2.git#egg=psycopg2
      dbapimain-postgresql: git+https://github.com/MagicStack/asyncpg.git#egg=asyncpg
      dbapimain-postgresql: git+https://github.com/tlocke/pg8000.git#egg=pg8000
+     dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg&subdirectory=psycopg
+     dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg-c&subdirectory=psycopg_c
 
      dbapimain-mysql: git+https://github.com/PyMySQL/mysqlclient-python.git#egg=mysqlclient
      dbapimain-mysql: git+https://github.com/PyMySQL/PyMySQL.git#egg=pymysql
@@ -89,14 +92,13 @@ setenv=
 
     sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
-    py3{,5,6,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
-    py3{,5,6,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}
+    py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
+    py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher}
     # omit pysqlcipher for Python 3.10
     py3{,10,11}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
 
     postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
-    py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
-    py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
+    py3{,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000 --dbdriver psycopg --dbdriver psycopg_async}
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
     py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}}
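
With the tox changes above, a ``postgresql`` run installs the new driver and
the default ``EXTRA_PG_DRIVERS`` value exercises it in both modes, for
example (the exact environment name depends on the local interpreter):

    # runs the suite against psycopg2, asyncpg, pg8000, psycopg and psycopg_async
    tox -e py310-postgresql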