From 5eb407f84bdabdbcd68975dbf76dc4c0809d7373 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 14 Sep 2021 23:38:00 +0200 Subject: [PATCH] Added support for ``psycopg`` dialect. Both sync and async versions are supported. Fixes: #6842 Change-Id: I57751c5028acebfc6f9c43572562405453a2f2a4 --- doc/build/changelog/migration_20.rst | 25 + doc/build/changelog/unreleased_14/6842.rst | 14 + doc/build/dialects/postgresql.rst | 7 + lib/sqlalchemy/dialects/mysql/asyncmy.py | 2 +- .../dialects/postgresql/__init__.py | 8 + .../dialects/postgresql/_psycopg_common.py | 188 +++++ lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1 - lib/sqlalchemy/dialects/postgresql/base.py | 8 +- lib/sqlalchemy/dialects/postgresql/psycopg.py | 641 ++++++++++++++++++ .../dialects/postgresql/psycopg2.py | 163 +---- lib/sqlalchemy/engine/create.py | 6 +- lib/sqlalchemy/engine/cursor.py | 1 + lib/sqlalchemy/engine/default.py | 12 + lib/sqlalchemy/engine/interfaces.py | 19 + lib/sqlalchemy/engine/url.py | 7 +- lib/sqlalchemy/ext/asyncio/engine.py | 1 + lib/sqlalchemy/sql/sqltypes.py | 42 +- lib/sqlalchemy/testing/config.py | 5 +- lib/sqlalchemy/testing/plugin/plugin_base.py | 3 +- lib/sqlalchemy/testing/requirements.py | 10 + lib/sqlalchemy/testing/suite/test_results.py | 2 + lib/sqlalchemy/testing/suite/test_types.py | 8 +- setup.cfg | 5 + test/dialect/postgresql/test_dialect.py | 134 +++- test/dialect/postgresql/test_query.py | 161 ++++- test/dialect/postgresql/test_types.py | 30 +- test/engine/test_parseconnect.py | 58 ++ test/engine/test_reconnect.py | 3 +- test/ext/asyncio/test_engine_py3k.py | 12 +- test/requirements.py | 27 +- test/sql/test_types.py | 47 +- tox.ini | 10 +- 32 files changed, 1395 insertions(+), 265 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6842.rst create mode 100644 lib/sqlalchemy/dialects/postgresql/_psycopg_common.py create mode 100644 lib/sqlalchemy/dialects/postgresql/psycopg.py diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index b75cefb31e..1c01698886 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -77,6 +77,31 @@ New Features and Improvements This section covers new features and improvements in SQLAlchemy 2.0 which are not otherwise part of the major 1.4->2.0 migration path. +.. _ticket_6842: + +Dialect support for psycopg 3 (a.k.a. "psycopg") +------------------------------------------------- + +Added dialect support for the `psycopg 3 `_ +DBAPI, which despite the number "3" now goes by the package name ``psycopg``, +superseding the previous ``psycopg2`` package that for the time being remains +SQLAlchemy's "default" driver for the ``postgresql`` dialects. ``psycopg`` is a +completely reworked and modernized database adapter for PostgreSQL which +supports concepts such as prepared statements as well as Python asyncio. + +``psycopg`` is the first DBAPI supported by SQLAlchemy which provides +both a pep-249 synchronous API as well as an asyncio driver. The same +``psycopg`` database URL may be used with the :func:`_sa.create_engine` +and :func:`_asyncio.create_async_engine` engine-creation functions, and the +corresponding sync or asyncio version of the dialect will be selected +automatically. + +.. seealso:: + + :ref:`postgresql_psycopg` + + + Behavioral Changes ================== diff --git a/doc/build/changelog/unreleased_14/6842.rst b/doc/build/changelog/unreleased_14/6842.rst new file mode 100644 index 0000000000..43b3841ce4 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6842.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: postgresql, dialect + :tickets: 6842 + + Added support for ``psycopg`` dialect supporting both sync and async + execution. This dialect is available under the ``postgresql+psycopg`` name + for both the :func:`_sa.create_engine` and + :func:`_asyncio.create_async_engine` engine-creation functions. + + .. seealso:: + + :ref:`ticket_6842` + + :ref:`postgresql_psycopg` \ No newline at end of file diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 958f8e0602..d3c9928c71 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -193,6 +193,13 @@ psycopg2 .. automodule:: sqlalchemy.dialects.postgresql.psycopg2 +.. _postgresql_psycopg: + +psycopg +-------- + +.. automodule:: sqlalchemy.dialects.postgresql.psycopg + pg8000 ------ diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index b595714603..8092e99ffb 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -175,7 +175,7 @@ class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): class AsyncAdapt_asyncmy_connection(AdaptedConnection): await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_connection", "_execute_mutex") + __slots__ = ("dbapi", "_execute_mutex") def __init__(self, dbapi, connection): self.dbapi = dbapi diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 08b05dc748..b0af5395e7 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -4,9 +4,12 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from types import ModuleType + from . import asyncpg # noqa from . import base from . import pg8000 # noqa +from . import psycopg # noqa from . import psycopg2 # noqa from . import psycopg2cffi # noqa from .array import All @@ -58,6 +61,11 @@ from .ranges import TSRANGE from .ranges import TSTZRANGE from ...util import compat +# Alias psycopg also as psycopg_asnyc +psycopg_async = type( + "psycopg_asnyc", (ModuleType,), {"dialect": psycopg.dialect_async} +) + base.dialect = dialect = psycopg2.dialect diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py new file mode 100644 index 0000000000..d82d5f0091 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -0,0 +1,188 @@ +import decimal + +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import PGDialect +from .base import PGExecutionContext +from .base import UUID +from .hstore import HSTORE +from ... import exc +from ... import processors +from ... import types as sqltypes +from ... import util + +_server_side_id = util.counter() + + +class _PsycopgNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # psycopg returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # psycopg returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PsycopgHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super(_PsycopgHStore, self).bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super(_PsycopgHStore, self).result_processor( + dialect, coltype + ) + + +class _PsycopgUUID(UUID): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +class _PsycopgARRAY(PGARRAY): + render_bind_cast = True + + +class _PGExecutionContext_common_psycopg(PGExecutionContext): + def create_server_side_cursor(self): + # use server-side cursors: + # psycopg + # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors + # psycopg2 + # https://www.psycopg.org/docs/usage.html#server-side-cursors + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + + +class _PGDialect_common_psycopg(PGDialect): + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + + _has_native_hstore = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PsycopgNumeric, + HSTORE: _PsycopgHStore, + UUID: _PsycopgUUID, + sqltypes.ARRAY: _PsycopgARRAY, + }, + ) + + def __init__( + self, + client_encoding=None, + use_native_hstore=True, + use_native_uuid=True, + **kwargs + ): + PGDialect.__init__(self, **kwargs) + if not use_native_hstore: + self._has_native_hstore = False + self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid + self.client_encoding = client_encoding + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user", database="dbname") + + is_multihost = False + if "host" in url.query: + is_multihost = isinstance(url.query["host"], (list, tuple)) + + if opts: + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + if is_multihost: + opts["host"] = ",".join(url.query["host"]) + # send individual dbname, user, password, host, port + # parameters to psycopg2.connect() + return ([], opts) + elif url.query: + # any other connection arguments, pass directly + opts.update(url.query) + if is_multihost: + opts["host"] = ",".join(url.query["host"]) + return ([], opts) + else: + # no connection arguments whatsoever; psycopg2.connect() + # requires that "dsn" be present as a blank string. + return ([""], opts) + + def get_isolation_level_values(self, dbapi_conn): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def _do_autocommit(self, connection, value): + connection.autocommit = value + + def do_ping(self, dbapi_connection): + cursor = None + try: + self._do_autocommit(dbapi_connection, True) + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + if not dbapi_connection.closed: + self._do_autocommit(dbapi_connection, False) + except self.dbapi.Error as err: + if self.is_disconnect(err, dbapi_connection, cursor): + return False + else: + raise + else: + return True diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 5ef9df8001..1fdb46b6f2 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -553,7 +553,6 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): class AsyncAdapt_asyncpg_connection(AdaptedConnection): __slots__ = ( "dbapi", - "_connection", "isolation_level", "_isolation_setting", "readonly", diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 0083988653..d1d881dc3d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1701,7 +1701,7 @@ class UUID(sqltypes.TypeEngine): or as Python uuid objects. The UUID type is currently known to work within the prominent DBAPI - drivers supported by SQLAlchemy including psycopg2, pg8000 and + drivers supported by SQLAlchemy including psycopg, psycopg2, pg8000 and asyncpg. Support for other DBAPI drivers may be incomplete or non-present. """ @@ -1992,6 +1992,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.connection.execute(DropEnumType(enum)) + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then note our name in that collection. diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py new file mode 100644 index 0000000000..c2017c9750 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -0,0 +1,641 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+psycopg + :name: psycopg (a.k.a. psycopg 3) + :dbapi: psycopg + :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg/ + +``psycopg`` is the package and module name for version 3 of the ``psycopg`` +database driver, formerly known as ``psycopg2``. This driver is different +enough from its ``psycopg2`` predecessor that SQLAlchemy supports it +via a totally separate dialect; support for ``psycopg2`` is expected to remain +for as long as that package continues to function for modern Python versions, +and also remains the default dialect for the ``postgresql://`` dialect +series. + +The SQLAlchemy ``psycopg`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will + automatically select the sync version, e.g.:: + + from sqlalchemy import create_engine + sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") + +* calling :func:`_asyncio.create_async_engine` with + ``postgresql+psycopg://...`` will automatically select the async version, + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") + +The asyncio version of the dialect may also be specified explicitly using the +``psycopg_async`` suffix, as:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") + +The ``psycopg`` dialect has the same API features as that of ``psycopg2``, +with the exeption of the "fast executemany" helpers. The "fast executemany" +helpers are expected to be generalized and ported to ``psycopg`` before the final +release of SQLAlchemy 2.0, however. + + +.. seealso:: + + :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg`` + dialect shares most of its behavior with the ``psycopg2`` dialect. + Further documentation is available there. + +""" # noqa +import logging +import re + +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from ._psycopg_common import _PsycopgUUID +from .base import INTERVAL +from .base import PGCompiler +from .base import PGIdentifierPreparer +from .base import UUID +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from ... import pool +from ... import types as sqltypes +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGJSON(JSON): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Json) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Jsonb) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + +class _PGUUID(_PsycopgUUID): + render_bind_cast = True + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + render_bind_cast = True + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg): + pass + + +class PGCompiler_psycopg(PGCompiler): + pass + + +class PGIdentifierPreparer_psycopg(PGIdentifierPreparer): + pass + + +def _log_notices(diagnostic): + logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary) + + +class PGDialect_psycopg(_PGDialect_common_psycopg): + driver = "psycopg" + + supports_statement_cache = True + supports_server_side_cursors = True + default_paramstyle = "pyformat" + supports_sane_multi_rowcount = True + + execution_ctx_cls = PGExecutionContext_psycopg + statement_compiler = PGCompiler_psycopg + preparer = PGIdentifierPreparer_psycopg + psycopg_version = (0, 0) + + _has_native_hstore = True + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + sqltypes.String: _PGString, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + UUID: _PGUUID, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.Date: _PGDate, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + }, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.dbapi: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg_version < (3, 0, 2): + raise ImportError( + "psycopg version 3.0.2 or higher is required." + ) + + from psycopg.adapt import AdaptersMap + + self._psycopg_adapters_map = adapters_map = AdaptersMap( + self.dbapi.adapters + ) + + if self._json_deserializer: + from psycopg.types.json import set_json_loads + + set_json_loads(self._json_deserializer, adapters_map) + + if self._json_serializer: + from psycopg.types.json import set_json_dumps + + set_json_dumps(self._json_serializer, adapters_map) + + def create_connect_args(self, url): + # see https://github.com/psycopg/psycopg/issues/83 + cargs, cparams = super().create_connect_args(url) + + cparams["context"] = self._psycopg_adapters_map + if self.client_encoding is not None: + cparams["client_encoding"] = self.client_encoding + return cargs, cparams + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + return TypeInfo.fetch(connection.connection, name) + + def initialize(self, connection): + super().initialize(connection) + + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.full_returning: + self.insert_executemany_returning = False + + # HSTORE can't be registered until we have a connection so that + # we can look up its OID, so we set up this adapter in + # initialize() + if self.use_native_hstore: + info = self._type_info_fetch(connection, "hstore") + self._has_native_hstore = info is not None + if self._has_native_hstore: + from psycopg.types.hstore import register_hstore + + # register the adapter for connections made subsequent to + # this one + register_hstore(info, self._psycopg_adapters_map) + + # register the adapter for this connection + register_hstore(info, connection.connection) + + @classmethod + def dbapi(cls): + import psycopg + + return psycopg + + @classmethod + def get_async_dialect_cls(cls, url): + return PGDialectAsync_psycopg + + @util.memoized_property + def _isolation_lookup(self): + return { + "READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED, + "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED, + "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ, + "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE, + } + + @util.memoized_property + def _psycopg_Json(self): + from psycopg.types import json + + return json.Json + + @util.memoized_property + def _psycopg_Jsonb(self): + from psycopg.types import json + + return json.Jsonb + + @util.memoized_property + def _psycopg_TransactionStatus(self): + from psycopg.pq import TransactionStatus + + return TransactionStatus + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.autocommit = autocommit + connection.isolation_level = isolation_level + + def get_isolation_level(self, dbapi_connection): + if hasattr(dbapi_connection, "dbapi_connection"): + dbapi_connection = dbapi_connection.dbapi_connection + + status_before = dbapi_connection.info.transaction_status + value = super().get_isolation_level(dbapi_connection) + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if status_before == self._psycopg_TransactionStatus.IDLE: + dbapi_connection.rollback() + return value + + def set_isolation_level(self, connection, level): + connection = getattr(connection, "dbapi_connection", connection) + if level == "AUTOCOMMIT": + self._do_isolation_level( + connection, autocommit=True, isolation_level=None + ) + else: + self._do_isolation_level( + connection, + autocommit=False, + isolation_level=self._isolation_lookup[level], + ) + + def set_readonly(self, connection, value): + connection.read_only = value + + def get_readonly(self, connection): + return connection.read_only + + def on_connect(self): + def notices(conn): + conn.add_notice_handler(_log_notices) + + fns = [notices] + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + # fns always has the notices function + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error) and connection is not None: + if connection.closed or connection.broken: + return True + return False + + def _do_prepared_twophase(self, connection, command, recover=False): + dbapi_conn = connection.connection.dbapi_connection + if ( + recover + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + or dbapi_conn.info.transaction_status + != self._psycopg_TransactionStatus.IDLE + ): + dbapi_conn.rollback() + before = dbapi_conn.autocommit + try: + self._do_autocommit(dbapi_conn, True) + dbapi_conn.execute(command) + finally: + self._do_autocommit(dbapi_conn, before) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"ROLLBACK PREPARED '{xid}'", recover=recover + ) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"COMMIT PREPARED '{xid}'", recover=recover + ) + else: + self.do_commit(connection.connection) + + +class AsyncAdapt_psycopg_cursor: + __slots__ = ("_cursor", "await_", "_rows") + + _psycopg_ExecStatus = None + + def __init__(self, cursor, await_) -> None: + self._cursor = cursor + self.await_ = await_ + self._rows = [] + + def __getattr__(self, name): + return getattr(self._cursor, name) + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + def close(self): + self._rows.clear() + # Normal cursor just call _close() in a non-sync way. + self._cursor._close() + + def execute(self, query, params=None, **kw): + result = self.await_(self._cursor.execute(query, params, **kw)) + # sqlalchemy result is not async, so need to pull all rows here + res = self._cursor.pgresult + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: + rows = self.await_(self._cursor.fetchall()) + if not isinstance(rows, list): + self._rows = list(rows) + else: + self._rows = rows + return result + + def executemany(self, query, params_seq): + return self.await_(self._cursor.executemany(query, params_seq)) + + def __iter__(self): + # TODO: try to avoid pop(0) on a list + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + # TODO: try to avoid pop(0) on a list + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self._cursor.arraysize + + retval = self._rows[0:size] + self._rows = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows + self._rows = [] + return retval + + +class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): + def execute(self, query, params=None, **kw): + self.await_(self._cursor.execute(query, params, **kw)) + return self + + def close(self): + self.await_(self._cursor.close()) + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=0): + return self.await_(self._cursor.fetchmany(size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + def __iter__(self): + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + + +class AsyncAdapt_psycopg_connection(AdaptedConnection): + __slots__ = () + await_ = staticmethod(await_only) + + def __init__(self, connection) -> None: + self._connection = connection + + def __getattr__(self, name): + return getattr(self._connection, name) + + def execute(self, query, params=None, **kw): + cursor = self.await_(self._connection.execute(query, params, **kw)) + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def cursor(self, *args, **kw): + cursor = self._connection.cursor(*args, **kw) + if hasattr(cursor, "name"): + return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_) + else: + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def commit(self): + self.await_(self._connection.commit()) + + def rollback(self): + self.await_(self._connection.rollback()) + + def close(self): + self.await_(self._connection.close()) + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self.set_autocommit(value) + + def set_autocommit(self, value): + self.await_(self._connection.set_autocommit(value)) + + def set_isolation_level(self, value): + self.await_(self._connection.set_isolation_level(value)) + + def set_read_only(self, value): + self.await_(self._connection.set_read_only(value)) + + def set_deferrable(self, value): + self.await_(self._connection.set_deferrable(value)) + + +class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection): + __slots__ = () + await_ = staticmethod(await_fallback) + + +class PsycopgAdaptDBAPI: + def __init__(self, psycopg) -> None: + self.psycopg = psycopg + + for k, v in self.psycopg.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + if util.asbool(async_fallback): + return AsyncAdaptFallback_psycopg_connection( + await_fallback( + self.psycopg.AsyncConnection.connect(*arg, **kw) + ) + ) + else: + return AsyncAdapt_psycopg_connection( + await_only(self.psycopg.AsyncConnection.connect(*arg, **kw)) + ) + + +class PGDialectAsync_psycopg(PGDialect_psycopg): + is_async = True + supports_statement_cache = True + + @classmethod + def dbapi(cls): + import psycopg + from psycopg.pq import ExecStatus + + AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus + + return PsycopgAdaptDBAPI(psycopg) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + adapted = connection.connection + return adapted.await_(TypeInfo.fetch(adapted._connection, name)) + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.set_autocommit(autocommit) + connection.set_isolation_level(isolation_level) + + def _do_autocommit(self, connection, value): + connection.set_autocommit(value) + + def set_readonly(self, connection, value): + connection.set_read_only(value) + + def set_deferrable(self, connection, value): + connection.set_deferrable(value) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_psycopg +dialect_async = PGDialectAsync_psycopg diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 39b6e0ed1e..3d9f90a297 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -11,6 +11,8 @@ r""" :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] :url: https://pypi.org/project/psycopg2/ +.. _psycopg2_toplevel: + psycopg2 Connect Arguments -------------------------- @@ -442,25 +444,15 @@ which may be more performant. """ # noqa import collections.abc as collections_abc -import decimal import logging import re -from uuid import UUID as _python_UUID -from .array import ARRAY as PGARRAY -from .base import _DECIMAL_TYPES -from .base import _FLOAT_TYPES -from .base import _INT_TYPES +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg from .base import PGCompiler -from .base import PGDialect -from .base import PGExecutionContext from .base import PGIdentifierPreparer -from .base import UUID -from .hstore import HSTORE from .json import JSON from .json import JSONB -from ... import exc -from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor @@ -469,53 +461,6 @@ from ...engine import cursor as _cursor logger = logging.getLogger("sqlalchemy.dialects.postgresql") -class _PGNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect, coltype): - if self.asdecimal: - if coltype in _FLOAT_TYPES: - return processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale - ) - elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: - # pg8000 returns Decimal natively for 1700 - return None - else: - raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype - ) - else: - if coltype in _FLOAT_TYPES: - # pg8000 returns float natively for 701 - return None - elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: - return processors.to_float - else: - raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype - ) - - -class _PGHStore(HSTORE): - def bind_processor(self, dialect): - if dialect._has_native_hstore: - return None - else: - return super(_PGHStore, self).bind_processor(dialect) - - def result_processor(self, dialect, coltype): - if dialect._has_native_hstore: - return None - else: - return super(_PGHStore, self).result_processor(dialect, coltype) - - -class _PGARRAY(PGARRAY): - render_bind_cast = True - - class _PGJSON(JSON): def result_processor(self, dialect, coltype): return None @@ -526,40 +471,9 @@ class _PGJSONB(JSONB): return None -class _PGUUID(UUID): - def bind_processor(self, dialect): - if not self.as_uuid and dialect.use_native_uuid: - - def process(value): - if value is not None: - value = _python_UUID(value) - return value - - return process - - def result_processor(self, dialect, coltype): - if not self.as_uuid and dialect.use_native_uuid: - - def process(value): - if value is not None: - value = str(value) - return value - - return process - - -_server_side_id = util.counter() - - -class PGExecutionContext_psycopg2(PGExecutionContext): +class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg): _psycopg2_fetched_rows = None - def create_server_side_cursor(self): - # use server-side cursors: - # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) - return self._dbapi_connection.cursor(ident) - def post_exec(self): if ( self._psycopg2_fetched_rows @@ -614,7 +528,7 @@ EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( ) -class PGDialect_psycopg2(PGDialect): +class PGDialect_psycopg2(_PGDialect_common_psycopg): driver = "psycopg2" supports_statement_cache = True @@ -631,34 +545,22 @@ class PGDialect_psycopg2(PGDialect): _has_native_hstore = True colspecs = util.update_copy( - PGDialect.colspecs, + _PGDialect_common_psycopg.colspecs, { - sqltypes.Numeric: _PGNumeric, - HSTORE: _PGHStore, JSON: _PGJSON, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, - UUID: _PGUUID, - sqltypes.ARRAY: _PGARRAY, }, ) def __init__( self, - client_encoding=None, - use_native_hstore=True, - use_native_uuid=True, executemany_mode="values_only", executemany_batch_page_size=100, executemany_values_page_size=1000, **kwargs ): - PGDialect.__init__(self, **kwargs) - if not use_native_hstore: - self._has_native_hstore = False - self.use_native_hstore = use_native_hstore - self.use_native_uuid = use_native_uuid - self.client_encoding = client_encoding + _PGDialect_common_psycopg.__init__(self, **kwargs) # Parse executemany_mode argument, allowing it to be only one of the # symbol names @@ -737,9 +639,6 @@ class PGDialect_psycopg2(PGDialect): "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, } - def get_isolation_level_values(self, dbapi_conn): - return list(self._isolation_lookup) - def set_isolation_level(self, connection, level): connection.set_isolation_level(self._isolation_lookup[level]) @@ -755,25 +654,6 @@ class PGDialect_psycopg2(PGDialect): def get_deferrable(self, connection): return connection.deferrable - def do_ping(self, dbapi_connection): - cursor = None - try: - dbapi_connection.autocommit = True - cursor = dbapi_connection.cursor() - try: - cursor.execute(self._dialect_specific_select_one) - finally: - cursor.close() - if not dbapi_connection.closed: - dbapi_connection.autocommit = False - except self.dbapi.Error as err: - if self.is_disconnect(err, dbapi_connection, cursor): - return False - else: - raise - else: - return True - def on_connect(self): extras = self._psycopg2_extras @@ -911,33 +791,6 @@ class PGDialect_psycopg2(PGDialect): else: return None - def create_connect_args(self, url): - opts = url.translate_connect_args(username="user") - - is_multihost = False - if "host" in url.query: - is_multihost = isinstance(url.query["host"], (list, tuple)) - - if opts: - if "port" in opts: - opts["port"] = int(opts["port"]) - opts.update(url.query) - if is_multihost: - opts["host"] = ",".join(url.query["host"]) - # send individual dbname, user, password, host, port - # parameters to psycopg2.connect() - return ([], opts) - elif url.query: - # any other connection arguments, pass directly - opts.update(url.query) - if is_multihost: - opts["host"] = ",".join(url.query["host"]) - return ([], opts) - else: - # no connection arguments whatsoever; psycopg2.connect() - # requires that "dsn" be present as a blank string. - return ([""], opts) - def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): # check the "closed" flag. this might not be diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f9a65a0f8d..8fcba75038 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -458,7 +458,11 @@ def create_engine(url, **kwargs): u, plugins, kwargs = u._instantiate_plugins(kwargs) entrypoint = u._get_entrypoint() - dialect_cls = entrypoint.get_dialect_cls(u) + _is_async = kwargs.pop("_is_async", False) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(u) + else: + dialect_cls = entrypoint.get_dialect_cls(u) if kwargs.pop("_coerce_config", False): diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1f1a2fcf1e..7f2b8b412c 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1223,6 +1223,7 @@ class BaseCursorResult: """ + if (not hard and self._soft_closed) or (hard and self.closed): return diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index cb04eb525b..64500b41ba 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1284,6 +1284,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): + if self.isinsert: if self.compiled.postfetch_lastrowid: self.inserted_primary_key_rows = ( @@ -1332,8 +1333,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # assert not result.returns_rows elif self.isupdate and self._is_implicit_returning: + # get rowcount + # (which requires open cursor on some drivers) + # we were not doing this in 1.4, however + # test_rowcount -> test_update_rowcount_return_defaults + # is testing this, and psycopg will no longer return + # rowcount after cursor is closed. + result.rowcount + row = result.fetchone() self.returned_default_rows = [row] + result._soft_close() # test that it has a cursor metadata that is accurate. @@ -1410,6 +1420,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): dialect = self.dialect + # all of the rest of this... cython? + if dialect._has_events: inputsizes = dict(inputsizes) dialect.dispatch.do_setinputsizes( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 251d01c5e2..faaf073ab0 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -1113,6 +1113,25 @@ class Dialect: """ return cls + @classmethod + def get_async_dialect_cls(cls, url): + """Given a URL, return the :class:`.Dialect` that will be used by + an async engine. + + By default this is an alias of :meth:`.Dialect.get_dialect_cls` and + just returns the cls. It may be used if a dialect provides + both a sync and async version under the same name, like the + ``psycopg`` driver. + + .. versionadded:: 2 + + .. seealso:: + + :meth:`.Dialect.get_dialect_cls` + + """ + return cls.get_dialect_cls(url) + @classmethod def load_provisioning(cls): """set up the provision.py module for this dialect. diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index c83753bdc9..7cdf25c21a 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -655,13 +655,16 @@ class URL( else: return cls - def get_dialect(self): + def get_dialect(self, _is_async=False): """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding to this URL's driver name. """ entrypoint = self._get_entrypoint() - dialect_cls = entrypoint.get_dialect_cls(self) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(self) + else: + dialect_cls = entrypoint.get_dialect_cls(self) return dialect_cls def translate_connect_args(self, names=None, **kw): diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 221d82f088..67d944b9ce 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -37,6 +37,7 @@ def create_async_engine(*arg, **kw): "streaming result set" ) kw["future"] = True + kw["_is_async"] = True sync_engine = _create_engine(*arg, **kw) return AsyncEngine(sync_engine) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 8874f0a83d..427251a885 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2338,26 +2338,40 @@ class JSON(Indexable, TypeEngine): def _str_impl(self): return String() - def bind_processor(self, dialect): - string_process = self._str_impl.bind_processor(dialect) + def _make_bind_processor(self, string_process, json_serializer): + if string_process: - json_serializer = dialect._json_serializer or json.dumps + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, elements.Null) or ( + value is None and self.none_as_null + ): + return None - def process(value): - if value is self.NULL: - value = None - elif isinstance(value, elements.Null) or ( - value is None and self.none_as_null - ): - return None + serialized = json_serializer(value) + return string_process(serialized) - serialized = json_serializer(value) - if string_process: - serialized = string_process(serialized) - return serialized + else: + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, elements.Null) or ( + value is None and self.none_as_null + ): + return None + + return json_serializer(value) return process + def bind_processor(self, dialect): + string_process = self._str_impl.bind_processor(dialect) + json_serializer = dialect._json_serializer or json.dumps + + return self._make_bind_processor(string_process, json_serializer) + def result_processor(self, dialect, coltype): string_process = self._str_impl.result_processor(dialect, coltype) json_deserializer = dialect._json_deserializer or json.loads diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 8faeea6341..22d9c523a2 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -124,11 +124,12 @@ class Config: _configs = set() def _set_name(self, db): + suffix = "_async" if db.dialect.is_async else "" if db.dialect.server_version_info: svi = ".".join(str(tok) for tok in db.dialect.server_version_info) - self.name = "%s+%s_[%s]" % (db.name, db.driver, svi) + self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi) else: - self.name = "%s+%s" % (db.name, db.driver) + self.name = "%s+%s%s" % (db.name, db.driver, suffix) @classmethod def register(cls, db, db_opts, options, file_config): diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 32ed2c3159..d79931b91e 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -706,7 +706,8 @@ def _do_skips(cls): ) if not all_configs: - msg = "'%s' unsupported on any DB implementation %s%s" % ( + msg = "'%s.%s' unsupported on any DB implementation %s%s" % ( + cls.__module__, cls.__name__, ", ".join( "'%s(%s)+%s'" diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 8cb72d1633..f811be6573 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1493,3 +1493,13 @@ class SuiteRequirements(Requirements): sequence. This should be false only for oracle. """ return exclusions.open() + + @property + def generic_classes(self): + "If X[Y] can be implemented with ``__class_getitem__``. py3.7+" + return exclusions.open() + + @property + def json_deserializer_binary(self): + "indicates if the json_deserializer function is called with bytes" + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 5ad68034b9..a8900ece14 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -244,6 +244,8 @@ class ServerSideCursorsTest( return cursor.server_side elif self.engine.dialect.driver == "pg8000": return getattr(cursor, "server_side", False) + elif self.engine.dialect.driver == "psycopg": + return bool(getattr(cursor, "name", False)) else: return False diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 82e6fa2383..e7131ec6ea 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -1107,7 +1107,13 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): eq_(row, (data_element,)) eq_(js.mock_calls, [mock.call(data_element)]) - eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + if testing.requires.json_deserializer_binary.enabled: + eq_( + jd.mock_calls, + [mock.call(json.dumps(data_element).encode())], + ) + else: + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) @testing.combinations( ("parameters",), diff --git a/setup.cfg b/setup.cfg index 624ace1bdf..80477f8a4e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,6 +66,7 @@ postgresql_asyncpg = asyncpg;python_version>="3" postgresql_psycopg2binary = psycopg2-binary postgresql_psycopg2cffi = psycopg2cffi +postgresql_psycopg = psycopg>=3.0.2 pymysql = pymysql;python_version>="3" pymysql<1;python_version<"3" @@ -157,6 +158,10 @@ sqlite_file = sqlite:///querytest.db aiosqlite_file = sqlite+aiosqlite:///async_querytest.db pysqlcipher_file = sqlite+pysqlcipher://:test@/querytest.db.enc postgresql = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test +psycopg2 = postgresql+psycopg2://scott:tiger@127.0.0.1:5432/test +psycopg = postgresql+psycopg://scott:tiger@127.0.0.1:5432/test +psycopg_async = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test +psycopg_async_fallback = postgresql+psycopg_async://scott:tiger@127.0.0.1:5432/test?async_fallback=true asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 57682686c8..02d7ad483e 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -8,6 +8,7 @@ from sqlalchemy import BigInteger from sqlalchemy import bindparam from sqlalchemy import cast from sqlalchemy import Column +from sqlalchemy import create_engine from sqlalchemy import DateTime from sqlalchemy import DDL from sqlalchemy import event @@ -30,6 +31,9 @@ from sqlalchemy import text from sqlalchemy import TypeDecorator from sqlalchemy import util from sqlalchemy.dialects.postgresql import base as postgresql +from sqlalchemy.dialects.postgresql import HSTORE +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import psycopg as psycopg_dialect from sqlalchemy.dialects.postgresql import psycopg2 as psycopg2_dialect from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_BATCH from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_PLAIN @@ -269,10 +273,12 @@ class PGCodeTest(fixtures.TestBase): if testing.against("postgresql+pg8000"): # TODO: is there another way we're supposed to see this? eq_(errmsg.orig.args[0]["C"], "23505") - else: + elif not testing.against("postgresql+psycopg"): eq_(errmsg.orig.pgcode, "23505") - if testing.against("postgresql+asyncpg"): + if testing.against("postgresql+asyncpg") or testing.against( + "postgresql+psycopg" + ): eq_(errmsg.orig.sqlstate, "23505") @@ -858,6 +864,13 @@ class MiscBackendTest( ".".join(str(x) for x in v) ) + @testing.only_on("postgresql+psycopg") + def test_psycopg_version(self): + v = testing.db.dialect.psycopg_version + assert testing.db.dialect.dbapi.__version__.startswith( + ".".join(str(x) for x in v) + ) + @testing.combinations( ((8, 1), False, False), ((8, 1), None, False), @@ -902,6 +915,7 @@ class MiscBackendTest( with testing.db.connect().execution_options( isolation_level="SERIALIZABLE" ) as conn: + dbapi_conn = conn.connection.dbapi_connection is_false(dbapi_conn.autocommit) @@ -1069,25 +1083,30 @@ class MiscBackendTest( dbapi_conn.rollback() eq_(val, "off") - @testing.requires.psycopg2_compatibility - def test_psycopg2_non_standard_err(self): + @testing.requires.psycopg_compatibility + def test_psycopg_non_standard_err(self): # note that psycopg2 is sometimes called psycopg2cffi # depending on platform - psycopg2 = testing.db.dialect.dbapi - TransactionRollbackError = __import__( - "%s.extensions" % psycopg2.__name__ - ).extensions.TransactionRollbackError + psycopg = testing.db.dialect.dbapi + if psycopg.__version__.startswith("3"): + TransactionRollbackError = __import__( + "%s.errors" % psycopg.__name__ + ).errors.TransactionRollback + else: + TransactionRollbackError = __import__( + "%s.extensions" % psycopg.__name__ + ).extensions.TransactionRollbackError exception = exc.DBAPIError.instance( "some statement", {}, TransactionRollbackError("foo"), - psycopg2.Error, + psycopg.Error, ) assert isinstance(exception, exc.OperationalError) @testing.requires.no_coverage - @testing.requires.psycopg2_compatibility + @testing.requires.psycopg_compatibility def test_notice_logging(self): log = logging.getLogger("sqlalchemy.dialects.postgresql") buf = logging.handlers.BufferingHandler(100) @@ -1115,14 +1134,14 @@ $$ LANGUAGE plpgsql; finally: log.removeHandler(buf) log.setLevel(lev) - msgs = " ".join(b.msg for b in buf.buffer) + msgs = " ".join(b.getMessage() for b in buf.buffer) eq_regex( msgs, - "NOTICE: notice: hi there(\nCONTEXT: .*?)? " - "NOTICE: notice: another note(\nCONTEXT: .*?)?", + "NOTICE: [ ]?notice: hi there(\nCONTEXT: .*?)? " + "NOTICE: [ ]?notice: another note(\nCONTEXT: .*?)?", ) - @testing.requires.psycopg2_or_pg8000_compatibility + @testing.requires.psycopg_or_pg8000_compatibility @engines.close_open_connections def test_client_encoding(self): c = testing.db.connect() @@ -1143,7 +1162,7 @@ $$ LANGUAGE plpgsql; new_encoding = c.exec_driver_sql("show client_encoding").fetchone()[0] eq_(new_encoding, test_encoding) - @testing.requires.psycopg2_or_pg8000_compatibility + @testing.requires.psycopg_or_pg8000_compatibility @engines.close_open_connections def test_autocommit_isolation_level(self): c = testing.db.connect().execution_options( @@ -1302,7 +1321,7 @@ $$ LANGUAGE plpgsql; assert result == [(1, "user", "lala")] connection.execute(text("DROP TABLE speedy_users")) - @testing.requires.psycopg2_or_pg8000_compatibility + @testing.requires.psycopg_or_pg8000_compatibility def test_numeric_raise(self, connection): stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric) assert_raises(exc.InvalidRequestError, connection.execute, stmt) @@ -1364,9 +1383,90 @@ $$ LANGUAGE plpgsql; ) @testing.requires.psycopg2_compatibility - def test_initial_transaction_state(self): + def test_initial_transaction_state_psycopg2(self): from psycopg2.extensions import STATUS_IN_TRANSACTION engine = engines.testing_engine() with engine.connect() as conn: ne_(conn.connection.status, STATUS_IN_TRANSACTION) + + @testing.only_on("postgresql+psycopg") + def test_initial_transaction_state_psycopg(self): + from psycopg.pq import TransactionStatus + + engine = engines.testing_engine() + with engine.connect() as conn: + ne_( + conn.connection.dbapi_connection.info.transaction_status, + TransactionStatus.INTRANS, + ) + + +class Psycopg3Test(fixtures.TestBase): + __only_on__ = ("postgresql+psycopg",) + + def test_json_correctly_registered(self, testing_engine): + import json + + def loads(value): + value = json.loads(value) + value["x"] = value["x"] + "_loads" + return value + + def dumps(value): + value = dict(value) + value["x"] = "dumps_y" + return json.dumps(value) + + engine = testing_engine( + options=dict(json_serializer=dumps, json_deserializer=loads) + ) + engine2 = testing_engine( + options=dict( + json_serializer=json.dumps, json_deserializer=json.loads + ) + ) + + s = select(cast({"key": "value", "x": "q"}, JSONB)) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) + with engine2.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "q"}) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) + + @testing.requires.hstore + def test_hstore_correctly_registered(self, testing_engine): + engine = testing_engine(options=dict(use_native_hstore=True)) + engine2 = testing_engine(options=dict(use_native_hstore=False)) + + def rp(self, *a): + return lambda a: {"a": "b"} + + with mock.patch.object(HSTORE, "result_processor", side_effect=rp): + s = select(cast({"key": "value", "x": "q"}, HSTORE)) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "q"}) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "q"}) + with engine2.begin() as conn: + eq_(conn.scalar(s), {"a": "b"}) + with engine.begin() as conn: + eq_(conn.scalar(s), {"key": "value", "x": "q"}) + + def test_get_dialect(self): + u = url.URL.create("postgresql://") + d = psycopg_dialect.PGDialect_psycopg.get_dialect_cls(u) + is_(d, psycopg_dialect.PGDialect_psycopg) + d = psycopg_dialect.PGDialect_psycopg.get_async_dialect_cls(u) + is_(d, psycopg_dialect.PGDialectAsync_psycopg) + d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u) + is_(d, psycopg_dialect.PGDialectAsync_psycopg) + d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u) + is_(d, psycopg_dialect.PGDialectAsync_psycopg) + + def test_async_version(self): + e = create_engine("postgresql+psycopg_async://") + is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg)) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index b488b146cf..fdce643f84 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -14,6 +14,7 @@ from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData @@ -29,6 +30,7 @@ from sqlalchemy import true from sqlalchemy import tuple_ from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.sql.expression import type_coerce from sqlalchemy.testing import assert_raises from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults @@ -40,6 +42,17 @@ from sqlalchemy.testing.assertsql import CursorSQL from sqlalchemy.testing.assertsql import DialectSQL +class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults): + __only_on__ = "postgresql" + __backend__ = True + + def test_count_star(self, connection): + eq_(connection.scalar(func.count("*")), 1) + + def test_count_int(self, connection): + eq_(connection.scalar(func.count(1)), 1) + + class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" @@ -956,23 +969,42 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): ], ) + def _strs_render_bind_casts(self, connection): + + return ( + connection.dialect._bind_typing_render_casts + and String().dialect_impl(connection.dialect).render_bind_cast + ) + @testing.requires.pyformat_paramstyle - def test_expression_pyformat(self): + def test_expression_pyformat(self, connection): matchtable = self.tables.matchtable - self.assert_compile( - matchtable.c.title.match("somstr"), - "matchtable.title @@ to_tsquery(%(title_1)s" ")", - ) + + if self._strs_render_bind_casts(connection): + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))", + ) + else: + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%(title_1)s)", + ) @testing.requires.format_paramstyle - def test_expression_positional(self): + def test_expression_positional(self, connection): matchtable = self.tables.matchtable - self.assert_compile( - matchtable.c.title.match("somstr"), - # note we assume current tested DBAPIs use emulated setinputsizes - # here, the cast is not strictly necessary - "matchtable.title @@ to_tsquery(%s::VARCHAR(200))", - ) + + if self._strs_render_bind_casts(connection): + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%s::VARCHAR(200))", + ) + else: + self.assert_compile( + matchtable.c.title.match("somstr"), + "matchtable.title @@ to_tsquery(%s)", + ) def test_simple_match(self, connection): matchtable = self.tables.matchtable @@ -1551,17 +1583,106 @@ class TableValuedRoundTripTest(fixtures.TestBase): [(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)], ) - @testing.only_on( - "postgresql+psycopg2", - "I cannot get this to run at all on other drivers, " - "even selecting from a table", + @testing.combinations( + ( + type_coerce, + testing.fails("fails on all drivers"), + ), + ( + cast, + testing.fails("fails on all drivers"), + ), + ( + None, + testing.fails_on_everything_except( + ["postgresql+psycopg2"], + "I cannot get this to run at all on other drivers, " + "even selecting from a table", + ), + ), + argnames="cast_fn", ) - def test_render_derived_quoting(self, connection): + def test_render_derived_quoting_text(self, connection, cast_fn): + + value = ( + '[{"CaseSensitive":1,"the % value":"foo"}, ' + '{"CaseSensitive":"2","the % value":"bar"}]' + ) + + if cast_fn: + value = cast_fn(value, JSON) + fn = ( - func.json_to_recordset( # noqa - '[{"CaseSensitive":1,"the % value":"foo"}, ' - '{"CaseSensitive":"2","the % value":"bar"}]' + func.json_to_recordset(value) + .table_valued( + column("CaseSensitive", Integer), column("the % value", String) ) + .render_derived(with_types=True) + ) + + stmt = select(fn.c.CaseSensitive, fn.c["the % value"]) + + eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")]) + + @testing.combinations( + ( + type_coerce, + testing.fails("fails on all drivers"), + ), + ( + cast, + testing.fails("fails on all drivers"), + ), + ( + None, + testing.fails("Fails on all drivers"), + ), + argnames="cast_fn", + ) + def test_render_derived_quoting_text_to_json(self, connection, cast_fn): + + value = ( + '[{"CaseSensitive":1,"the % value":"foo"}, ' + '{"CaseSensitive":"2","the % value":"bar"}]' + ) + + if cast_fn: + value = cast_fn(value, JSON) + + # why wont this work?!?!? + # should be exactly json_to_recordset(to_json('string'::text)) + # + fn = ( + func.json_to_recordset(func.to_json(value)) + .table_valued( + column("CaseSensitive", Integer), column("the % value", String) + ) + .render_derived(with_types=True) + ) + + stmt = select(fn.c.CaseSensitive, fn.c["the % value"]) + + eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")]) + + @testing.combinations( + (type_coerce,), + (cast,), + (None, testing.fails("Fails on all PG backends")), + argnames="cast_fn", + ) + def test_render_derived_quoting_straight_json(self, connection, cast_fn): + # these all work + + value = [ + {"CaseSensitive": 1, "the % value": "foo"}, + {"CaseSensitive": "2", "the % value": "bar"}, + ] + + if cast_fn: + value = cast_fn(value, JSON) + + fn = ( + func.json_to_recordset(value) # noqa .table_valued( column("CaseSensitive", Integer), column("the % value", String) ) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 4008881d21..5f8a41d1f8 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -2360,7 +2360,7 @@ class ArrayEnum(fixtures.TestBase): testing.combinations( sqltypes.ARRAY, postgresql.ARRAY, - (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), + (_ArrayOfEnum, testing.requires.psycopg_compatibility), argnames="array_cls", )(fn) ) @@ -3066,7 +3066,7 @@ class HStoreRoundTripTest(fixtures.TablesTest): @testing.fixture def non_native_hstore_connection(self, testing_engine): - local_engine = testing.requires.psycopg2_native_hstore.enabled + local_engine = testing.requires.native_hstore.enabled if local_engine: engine = testing_engine(options=dict(use_native_hstore=False)) @@ -3096,14 +3096,14 @@ class HStoreRoundTripTest(fixtures.TablesTest): )["1"] eq_(connection.scalar(select(expr)), "3") - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_insert_native(self, connection): self._test_insert(connection) def test_insert_python(self, non_native_hstore_connection): self._test_insert(non_native_hstore_connection) - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_criterion_native(self, connection): self._fixture_data(connection) self._test_criterion(connection) @@ -3134,7 +3134,7 @@ class HStoreRoundTripTest(fixtures.TablesTest): def test_fixed_round_trip_python(self, non_native_hstore_connection): self._test_fixed_round_trip(non_native_hstore_connection) - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_fixed_round_trip_native(self, connection): self._test_fixed_round_trip(connection) @@ -3154,11 +3154,11 @@ class HStoreRoundTripTest(fixtures.TablesTest): }, ) - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_unicode_round_trip_python(self, non_native_hstore_connection): self._test_unicode_round_trip(non_native_hstore_connection) - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_unicode_round_trip_native(self, connection): self._test_unicode_round_trip(connection) @@ -3167,7 +3167,7 @@ class HStoreRoundTripTest(fixtures.TablesTest): ): self._test_escaped_quotes_round_trip(non_native_hstore_connection) - @testing.requires.psycopg2_native_hstore + @testing.requires.native_hstore def test_escaped_quotes_round_trip_native(self, connection): self._test_escaped_quotes_round_trip(connection) @@ -3356,7 +3356,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): class _RangeTypeRoundTrip(fixtures.TablesTest): - __requires__ = "range_types", "psycopg2_compatibility" + __requires__ = "range_types", "psycopg_compatibility" __backend__ = True def extras(self): @@ -3364,8 +3364,18 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): # older psycopg2 versions. if testing.against("postgresql+psycopg2cffi"): from psycopg2cffi import extras - else: + elif testing.against("postgresql+psycopg2"): from psycopg2 import extras + elif testing.against("postgresql+psycopg"): + from psycopg.types.range import Range + + class psycopg_extras: + def __getattr__(self, _): + return Range + + extras = psycopg_extras() + else: + assert False, "Unknonw dialect" return extras @classmethod diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index c69b332a8f..f12d32d5d0 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1037,6 +1037,64 @@ class TestRegNewDBAPI(fixtures.TestBase): ) +class TestGetDialect(fixtures.TestBase): + @testing.requires.sqlite + @testing.combinations(True, False, None) + def test_is_async_to_create_engine(self, is_async): + def get_dialect_cls(url): + url = url.set(drivername="sqlite") + return url.get_dialect() + + global MockDialectGetDialect + MockDialectGetDialect = Mock() + MockDialectGetDialect.get_dialect_cls.side_effect = get_dialect_cls + MockDialectGetDialect.get_async_dialect_cls.side_effect = ( + get_dialect_cls + ) + + registry.register("mockdialect", __name__, "MockDialectGetDialect") + + from sqlalchemy.dialects import sqlite + + kw = {} + if is_async is not None: + kw["_is_async"] = is_async + e = create_engine("mockdialect://", **kw) + + eq_(e.dialect.name, "sqlite") + assert isinstance(e.dialect, sqlite.dialect) + + if is_async: + eq_( + MockDialectGetDialect.mock_calls, + [ + call.get_async_dialect_cls(url.make_url("mockdialect://")), + call.engine_created(e), + ], + ) + else: + eq_( + MockDialectGetDialect.mock_calls, + [ + call.get_dialect_cls(url.make_url("mockdialect://")), + call.engine_created(e), + ], + ) + MockDialectGetDialect.reset_mock() + u = url.make_url("mockdialect://") + u.get_dialect(**kw) + if is_async: + eq_( + MockDialectGetDialect.mock_calls, + [call.get_async_dialect_cls(u)], + ) + else: + eq_( + MockDialectGetDialect.mock_calls, + [call.get_dialect_cls(u)], + ) + + class MockDialect(DefaultDialect): @classmethod def dbapi(cls, **kw): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index afd027698e..f37d1893f0 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -365,7 +365,7 @@ class MockReconnectTest(fixtures.TestBase): ) self.mock_connect = call( - host="localhost", password="bar", user="foo", database="test" + host="localhost", password="bar", user="foo", dbname="test" ) # monkeypatch disconnect checker self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance( @@ -1321,6 +1321,7 @@ class InvalidateDuringResultTest(fixtures.TestBase): "+aiosqlite", "+aiomysql", "+asyncmy", + "+psycopg", ], "Buffers the result set and doesn't check for connection close", ) diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 44cf9388ca..aee71f8d55 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -634,7 +634,11 @@ class AsyncEventTest(EngineFixture): eq_( canary.mock_calls, - [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)], + [ + mock.call( + sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False + ) + ], ) @async_test @@ -651,7 +655,11 @@ class AsyncEventTest(EngineFixture): eq_( canary.mock_calls, - [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)], + [ + mock.call( + sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False + ) + ], ) @async_test diff --git a/test/requirements.py b/test/requirements.py index eeb71323bc..3934dd23fc 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -223,6 +223,7 @@ class DefaultRequirements(SuiteRequirements): def pyformat_paramstyle(self): return only_on( [ + "postgresql+psycopg", "postgresql+psycopg2", "postgresql+psycopg2cffi", "mysql+mysqlconnector", @@ -1161,7 +1162,10 @@ class DefaultRequirements(SuiteRequirements): @property def infinity_floats(self): return fails_on_everything_except( - "sqlite", "postgresql+psycopg2", "postgresql+asyncpg" + "sqlite", + "postgresql+psycopg2", + "postgresql+asyncpg", + "postgresql+psycopg", ) + skip_if( "postgresql+pg8000", "seems to work on pg14 only, not earlier?" ) @@ -1241,9 +1245,7 @@ class DefaultRequirements(SuiteRequirements): @property def range_types(self): def check_range_types(config): - if not against( - config, ["postgresql+psycopg2", "postgresql+psycopg2cffi"] - ): + if not self.psycopg_compatibility.enabled: return False try: with config.db.connect() as conn: @@ -1291,23 +1293,27 @@ class DefaultRequirements(SuiteRequirements): ) @property - def psycopg2_native_hstore(self): - return self.psycopg2_compatibility + def native_hstore(self): + return self.psycopg_compatibility @property def psycopg2_compatibility(self): return only_on(["postgresql+psycopg2", "postgresql+psycopg2cffi"]) @property - def psycopg2_or_pg8000_compatibility(self): + def psycopg_compatibility(self): return only_on( [ "postgresql+psycopg2", "postgresql+psycopg2cffi", - "postgresql+pg8000", + "postgresql+psycopg", ] ) + @property + def psycopg_or_pg8000_compatibility(self): + return only_on([self.psycopg_compatibility, "postgresql+pg8000"]) + @property def percent_schema_names(self): return skip_if( @@ -1690,3 +1696,8 @@ class DefaultRequirements(SuiteRequirements): def reflect_tables_no_columns(self): # so far sqlite, mariadb, mysql don't support this return only_on(["postgresql"]) + + @property + def json_deserializer_binary(self): + "indicates if the json_deserializer function is called with bytes" + return only_on(["postgresql+psycopg"]) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index f63e1e01bc..35467de941 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -2459,15 +2459,16 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): @testing.combinations( (True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit" ) - @testing.provide_metadata @testing.skip_if("mysql < 8") - def test_duplicate_values_accepted(self, native, omit): + def test_duplicate_values_accepted( + self, metadata, connection, native, omit + ): foo_enum = pep435_enum("foo_enum") foo_enum("one", 1, "two") foo_enum("three", 3, "four") tbl = sa.Table( "foo_table", - self.metadata, + metadata, sa.Column("id", sa.Integer), sa.Column( "data", @@ -2481,7 +2482,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): ) t = sa.table("foo_table", sa.column("id"), sa.column("data")) - self.metadata.create_all(testing.db) + metadata.create_all(connection) if omit: with expect_raises( ( @@ -2491,29 +2492,27 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): exc.DBAPIError, ) ): - with testing.db.begin() as conn: - conn.execute( - t.insert(), - [ - {"id": 1, "data": "four"}, - {"id": 2, "data": "three"}, - ], - ) - else: - with testing.db.begin() as conn: - conn.execute( + connection.execute( t.insert(), - [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}], + [ + {"id": 1, "data": "four"}, + {"id": 2, "data": "three"}, + ], ) + else: + connection.execute( + t.insert(), + [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}], + ) - eq_( - conn.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "four"), (2, "three")], - ) - eq_( - conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(), - [(1, foo_enum.three), (2, foo_enum.three)], - ) + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "four"), (2, "three")], + ) + eq_( + connection.execute(tbl.select().order_by(tbl.c.id)).fetchall(), + [(1, foo_enum.three), (2, foo_enum.three)], + ) MyPickleType = None diff --git a/tox.ini b/tox.ini index d8ba67a440..9f6ccb0c65 100644 --- a/tox.ini +++ b/tox.ini @@ -27,6 +27,7 @@ deps= postgresql: .[postgresql] postgresql: .[postgresql_asyncpg]; python_version >= '3' postgresql: .[postgresql_pg8000]; python_version >= '3' + postgresql: .[postgresql_psycopg]; python_version >= '3' mysql: .[mysql] mysql: .[pymysql] @@ -43,6 +44,8 @@ deps= dbapimain-postgresql: git+https://github.com/psycopg/psycopg2.git#egg=psycopg2 dbapimain-postgresql: git+https://github.com/MagicStack/asyncpg.git#egg=asyncpg dbapimain-postgresql: git+https://github.com/tlocke/pg8000.git#egg=pg8000 + dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg&subdirectory=psycopg + dbapimain-postgresql: git+https://github.com/psycopg/psycopg.git#egg=psycopg-c&subdirectory=psycopg_c dbapimain-mysql: git+https://github.com/PyMySQL/mysqlclient-python.git#egg=mysqlclient dbapimain-mysql: git+https://github.com/PyMySQL/PyMySQL.git#egg=pymysql @@ -89,14 +92,13 @@ setenv= sqlite: SQLITE={env:TOX_SQLITE:--db sqlite} sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} - py3{,5,6,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} - py3{,5,6,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher} + py3{,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} + py3{,7,8,9}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite --dbdriver pysqlcipher} # omit pysqlcipher for Python 3.10 py3{,10,11}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} - py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}} - py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000} + py3{,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000 --dbdriver psycopg --dbdriver psycopg_async} mysql: MYSQL={env:TOX_MYSQL:--db mysql} py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}} -- 2.47.2