From: Tony Locke Date: Sun, 2 Aug 2020 19:19:26 +0000 (-0400) Subject: Update dialect for pg8000 version 1.16.0 X-Git-Tag: rel_1_4_0b1~162^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=06f1929b866abc2af0ff5c838e472a8b1c98d6e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Update dialect for pg8000 version 1.16.0 The pg8000 dialect has been revised and modernized for the most recent version of the pg8000 driver for PostgreSQL. Changes to the dialect include: * All data types are now sent as text rather than binary. * Using adapters, custom types can be plugged in to pg8000. * Previously, named prepared statements were used for all statements. Now unnamed prepared statements are used by default, and named prepared statements can be used explicitly by calling the Connection.prepare() method, which returns a PreparedStatement object. Pull request courtesy Tony Locke. Notes by Mike: to get this all working it was needed to break up JSONIndexType into "str" and "int" subtypes; this will be needed for any dialect that is dependent on setinputsizes(). also includes @caselit's idea to include query params in the dbdriver parameter. Co-authored-by: Mike Bayer Closes: #5451 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5451 Pull-request-sha: 639751ca9c7544801b9ede02e6cbe15a16c59c82 Change-Id: I2869bc52c330916773a41d11d12c297aecc8fcd8 --- diff --git a/doc/build/changelog/unreleased_14/pg8000.rst b/doc/build/changelog/unreleased_14/pg8000.rst new file mode 100644 index 0000000000..17c0a9d1c7 --- /dev/null +++ b/doc/build/changelog/unreleased_14/pg8000.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, postgresql + + The pg8000 dialect has been revised and modernized for the most recent + version of the pg8000 driver for PostgreSQL. Changes to the dialect + include: + + * All data types are now sent as text rather than binary. + + * Using adapters, custom types can be plugged in to pg8000. + + * Previously, named prepared statements were used for all statements. + Now unnamed prepared statements are used by default, and named + prepared statements can be used explicitly by calling the + Connection.prepare() method, which returns a PreparedStatement + object. + + Pull request courtesy Tony Locke. diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py index b86056da6c..bbe752d782 100644 --- a/lib/sqlalchemy/dialects/mysql/provision.py +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -9,7 +9,7 @@ from ...testing.provision import temp_table_keyword_args @generate_driver_url.for_db("mysql", "mariadb") -def generate_driver_url(url, driver): +def generate_driver_url(url, driver, query): backend = url.get_backend_name() if backend == "mysql": @@ -18,7 +18,10 @@ def generate_driver_url(url, driver): backend = "mariadb" new_url = copy.copy(url) + new_url.query = dict(new_url.query) new_url.drivername = "%s+%s" % (backend, driver) + new_url.query.update(query) + try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 57c8f5a9af..e08332a570 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -9,13 +9,11 @@ r""" :name: pg8000 :dbapi: pg8000 :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] - :url: https://pythonhosted.org/pg8000/ + :url: https://pypi.org/project/pg8000/ -.. note:: - - The pg8000 dialect is **not tested as part of SQLAlchemy's continuous - integration** and may have unresolved issues. The recommended PostgreSQL - dialect is psycopg2. +.. versionchanged:: 1.4 The pg8000 dialect has been updated for version + 1.16.5 and higher, and is again part of SQLAlchemy's continuous integration + with full feature support. .. _pg8000_unicode: @@ -56,9 +54,6 @@ of the :ref:`psycopg2 ` dialect: * ``SERIALIZABLE`` * ``AUTOCOMMIT`` -.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using - pg8000. - .. seealso:: :ref:`postgresql_isolation_level` @@ -74,12 +69,16 @@ from uuid import UUID as _python_UUID from .base import _DECIMAL_TYPES from .base import _FLOAT_TYPES from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL from .base import PGCompiler from .base import PGDialect from .base import PGExecutionContext from .base import PGIdentifierPreparer from .base import UUID from .json import JSON +from .json import JSONB +from .json import JSONPathType from ... import exc from ... import processors from ... import types as sqltypes @@ -125,6 +124,40 @@ class _PGJSON(JSON): else: return super(_PGJSON, self).result_processor(dialect, coltype) + def get_dbapi_type(self, dbapi): + return dbapi.JSON + + +class _PGJSONB(JSONB): + def result_processor(self, dialect, coltype): + if dialect._dbapi_version > (1, 10, 1): + return None # Has native JSON + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + + def get_dbapi_type(self, dbapi): + return dbapi.JSONB + + +class _PGJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class _PGJSONPathType(JSONPathType): + def get_dbapi_type(self, dbapi): + return 1009 + class _PGUUID(UUID): def bind_processor(self, dialect): @@ -148,8 +181,67 @@ class _PGUUID(UUID): return process +class _PGEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.UNKNOWN + + +class _PGInterval(INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return _PGInterval(precision=interval.second_precision) + + +class _PGTimeStamp(sqltypes.DateTime): + def get_dbapi_type(self, dbapi): + if self.timezone: + # TIMESTAMPTZOID + return 1184 + else: + # TIMESTAMPOID + return 1114 + + +class _PGTime(sqltypes.Time): + def get_dbapi_type(self, dbapi): + return dbapi.TIME + + +class _PGInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGSmallInteger(sqltypes.SmallInteger): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGNullType(sqltypes.NullType): + def get_dbapi_type(self, dbapi): + return dbapi.NULLTYPE + + +class _PGBigInteger(sqltypes.BigInteger): + def get_dbapi_type(self, dbapi): + return dbapi.BIGINTEGER + + +class _PGBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.BOOLEAN + + class PGExecutionContext_pg8000(PGExecutionContext): - pass + def pre_exec(self): + if not self.compiled: + return + + if self.dialect._dbapi_version > (1, 16, 0): + self.set_input_sizes() class PGCompiler_pg8000(PGCompiler): @@ -160,20 +252,11 @@ class PGCompiler_pg8000(PGCompiler): + self.process(binary.right, **kw) ) - def post_process_text(self, text): - if "%%" in text: - util.warn( - "The SQLAlchemy postgresql dialect " - "now automatically escapes '%' in text() " - "expressions to '%%'." - ) - return text.replace("%", "%%") - class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): - def _escape_identifier(self, value): - value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace("%", "%%") + def __init__(self, *args, **kwargs): + PGIdentifierPreparer.__init__(self, *args, **kwargs) + self._double_percents = False class PGDialect_pg8000(PGDialect): @@ -195,9 +278,23 @@ class PGDialect_pg8000(PGDialect): { sqltypes.Numeric: _PGNumericNoBind, sqltypes.Float: _PGNumeric, - JSON: _PGJSON, sqltypes.JSON: _PGJSON, + sqltypes.Boolean: _PGBoolean, + sqltypes.NullType: _PGNullType, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIndexType: _PGJSONIndexType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, UUID: _PGUUID, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + sqltypes.Enum: _PGEnum, }, ) @@ -313,6 +410,17 @@ class PGDialect_pg8000(PGDialect): fns.append(on_connect) + if self._dbapi_version > (1, 16, 0) and self._json_deserializer: + + def on_connect(conn): + # json + conn.register_in_adapter(114, self._json_deserializer) + + # jsonb + conn.register_in_adapter(3802, self._json_deserializer) + + fns.append(on_connect) + if len(fns) > 0: def on_connect(conn): diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index eb82a411eb..6c6dc4be64 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -1,30 +1,13 @@ -import copy import time from ... import exc from ... import text from ...testing.provision import create_db from ...testing.provision import drop_db -from ...testing.provision import generate_driver_url from ...testing.provision import log from ...testing.provision import temp_table_keyword_args -@generate_driver_url.for_db("postgresql") -def generate_driver_url(url, driver): - new_url = copy.copy(url) - new_url.drivername = "postgresql+%s" % driver - if new_url.get_driver_name() == "asyncpg": - new_url.query = dict(new_url.query) - new_url.query["async_fallback"] = "true" - try: - new_url.get_dialect() - except exc.NoSuchModuleError: - return None - else: - return new_url - - @create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index fe74be8235..f9fabbeed5 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -346,12 +346,12 @@ def _expect_raises(except_cls, msg=None, check_context=False): assert success, "Callable did not raise an exception" -def expect_raises(except_cls): - return _expect_raises(except_cls, check_context=True) +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) -def expect_raises_message(except_cls, msg): - return _expect_raises(except_cls, msg=msg, check_context=True) +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) class AssertsCompiledSQL(object): diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 21bacfca2f..094d1ea94b 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -7,6 +7,7 @@ from . import engines from .. import exc from ..engine import url as sa_url from ..util import compat +from ..util import parse_qsl log = logging.getLogger(__name__) @@ -85,7 +86,7 @@ def generate_db_urls(db_urls, extra_drivers): --dburi postgresql://db1 \ --dburi postgresql://db2 \ --dburi postgresql://db2 \ - --dbdriver=psycopg2 --dbdriver=asyncpg + --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true Noting that the default postgresql driver is psycopg2. the output would be:: @@ -139,21 +140,34 @@ def _generate_driver_urls(url, extra_drivers): main_driver = url.get_driver_name() extra_drivers.discard(main_driver) - url = generate_driver_url(url, main_driver) + url = generate_driver_url(url, main_driver, {}) yield str(url) for drv in list(extra_drivers): - new_url = generate_driver_url(url, drv) + + if "?" in drv: + + driver_only, query_str = drv.split("?", 1) + + query = parse_qsl(query_str) + else: + driver_only = drv + query = {} + + new_url = generate_driver_url(url, driver_only, query) if new_url: extra_drivers.remove(drv) + yield str(new_url) @register.init -def generate_driver_url(url, driver): +def generate_driver_url(url, driver, query): backend = url.get_backend_name() new_url = copy.copy(url) + new_url.query = dict(new_url.query) new_url.drivername = "%s+%s" % (backend, driver) + new_url.query.update(query) try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 5e6ac1eabd..6a390231bb 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -34,7 +34,6 @@ from ... import testing from ... import Text from ... import Time from ... import TIMESTAMP -from ... import type_coerce from ... import TypeDecorator from ... import Unicode from ... import UnicodeText @@ -1161,37 +1160,6 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest): and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) - def test_crit_against_int_basic(self): - name = self.tables.data_table.c.name - col = self.tables.data_table.c["data"] - - self._test_index_criteria( - and_(name == "r6", cast(col["a"], String) == "5"), "r6" - ) - - def _dont_test_crit_against_string_coerce_type(self): - name = self.tables.data_table.c.name - col = self.tables.data_table.c["data"] - - self._test_index_criteria( - and_( - name == "r6", - cast(col["b"], String) == type_coerce("some value", JSON), - ), - "r6", - test_literal=False, - ) - - def _dont_test_crit_against_int_coerce_type(self): - name = self.tables.data_table.c.name - col = self.tables.data_table.c["data"] - - self._test_index_criteria( - and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)), - "r6", - test_literal=False, - ) - __all__ = ( "UnicodeVarcharTest", diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index d15e3a843c..6eaa3295b9 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -982,7 +982,7 @@ $$ LANGUAGE plpgsql; t = Table("t", m, Column("c", type_, primary_key=True)) if version: - dialect = postgresql.dialect() + dialect = testing.db.dialect.__class__() dialect._get_server_version_info = mock.Mock( return_value=version ) diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 0d02ab3e7b..f09f0f1e12 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -579,10 +579,6 @@ class RawExecuteTest(fixtures.TablesTest): Column("user_name", VARCHAR(20)), ) - @testing.fails_on( - "postgresql+pg8000", - "pg8000 still doesn't allow single paren without params", - ) def test_no_params_option(self, connection): stmt = ( "SELECT '%'" diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 89d5c63486..fd42224ebb 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -77,10 +77,6 @@ class ExecuteTest(fixtures.TablesTest): Column("user_name", VARCHAR(20)), ) - @testing.fails_on( - "postgresql+pg8000", - "pg8000 still doesn't allow single paren without params", - ) def test_no_params_option(self): stmt = ( "SELECT '%'" diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 48eb485cb7..53a5ec6f4d 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -17,6 +17,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assert_raises_message_context_ok from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -1312,12 +1313,10 @@ class PrePingRealTest(fixtures.TestBase): eq_(conn.execute(select(1)).scalar(), 1) conn.close() - def exercise_stale_connection(): + with expect_raises(engine.dialect.dbapi.Error, check_context=False): curs = stale_connection.cursor() curs.execute("select 1") - assert_raises(engine.dialect.dbapi.Error, exercise_stale_connection) - def test_pre_ping_db_stays_shutdown(self): engine = engines.reconnecting_engine(options={"pool_pre_ping": True}) diff --git a/test/requirements.py b/test/requirements.py index 99a3605658..99a6f5a3b4 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1017,7 +1017,7 @@ class DefaultRequirements(SuiteRequirements): @property def json_array_indexes(self): - return self.json_type + fails_if("+pg8000") + return self.json_type @property def datetime_literals(self): @@ -1209,20 +1209,6 @@ class DefaultRequirements(SuiteRequirements): "Firebird still has FP inaccuracy even " "with only four decimal places", ), - ( - "postgresql+pg8000", - None, - None, - "postgresql+pg8000 has FP inaccuracy even with " - "only four decimal places ", - ), - ( - "postgresql+psycopg2cffi", - None, - None, - "postgresql+psycopg2cffi has FP inaccuracy even with " - "only four decimal places ", - ), ] ) @@ -1253,7 +1239,7 @@ class DefaultRequirements(SuiteRequirements): @property def duplicate_key_raises_integrity_error(self): - return fails_on("postgresql+pg8000") + return exclusions.open() def _has_pg_extension(self, name): def check(config): diff --git a/tox.ini b/tox.ini index 92fe031724..ac95dc42c6 100644 --- a/tox.ini +++ b/tox.ini @@ -21,7 +21,8 @@ deps=pytest!=3.9.1,!=3.9.2 mock; python_version < '3.3' importlib_metadata; python_version < '3.8' postgresql: .[postgresql] - postgresql: .[postgresql_asyncpg] + postgresql: .[postgresql_asyncpg]; python_version >= '3' + postgresql: .[postgresql_pg8000]; python_version >= '3' mysql: .[mysql] mysql: .[pymysql] oracle: .[oracle] @@ -66,11 +67,12 @@ setenv= sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} - postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg} + py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg?async_fallback=true --dbdriver pg8000} mysql: MYSQL={env:TOX_MYSQL:--db mysql} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql} + mssql: MSSQL={env:TOX_MSSQL:--db mssql} oracle,mssql,sqlite_file: IDENTS=--write-idents db_idents.txt