Update dialect for pg8000 version 1.16.0
author     Tony Locke <tlocke@tlocke.org.uk>
           Sun, 2 Aug 2020 19:19:26 +0000 (15:19 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 18 Aug 2020 15:12:16 +0000 (11:12 -0400)
The pg8000 dialect has been revised and modernized for the most recent
version of the pg8000 driver for PostgreSQL.  Changes to the dialect
include:

* All data types are now sent as text rather than binary.

* Using adapters, custom types can be plugged in to pg8000.

* Previously, named prepared statements were used for all statements.
  Now unnamed prepared statements are used by default, and named
  prepared statements can be used explicitly by calling the
  Connection.prepare() method, which returns a PreparedStatement
  object (see the sketch below).
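
A minimal sketch of that explicit prepared-statement usage, calling
pg8000's native interface directly; this assumes pg8000's
pg8000.native.Connection API (1.16 and later), and the connection
arguments and query are placeholders, not part of this change:

    import pg8000.native

    # placeholder credentials for illustration only
    con = pg8000.native.Connection("user", password="secret")

    ps = con.prepare("SELECT cast(:v as varchar)")  # named prepared statement
    print(ps.run(v="speedy"))                       # [['speedy']]
    ps.close()
    con.close()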

Pull request courtesy Tony Locke.

Notes by Mike: to get this all working, it was necessary to break
up JSONIndexType into "str" and "int" subtypes; this will be
needed for any dialect that depends on setinputsizes().
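
As a rough illustration of that split (the table below is hypothetical
and not part of the diff), an integer index and a string index into a
JSON/JSONB column now carry distinct bound-parameter types, which the
pg8000 dialect maps to dbapi.INTEGER and dbapi.STRING via the new
_PGJSONIntIndexType / _PGJSONStrIndexType classes:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import JSONB

    # hypothetical table for illustration
    m = MetaData()
    t = Table(
        "data_table", m,
        Column("id", Integer, primary_key=True),
        Column("data", JSONB),
    )

    # the bound index parameters are typed as JSONIntIndexType and
    # JSONStrIndexType respectively, which pg8000 sends as INTEGER / STRING
    int_idx = t.c.data[1]
    str_idx = t.c.data["b"]

    print(int_idx.compile(dialect=postgresql.dialect()))
    print(str_idx.compile(dialect=postgresql.dialect()))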

Also includes @caselit's idea to include query parameters in the
--dbdriver parameter.
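
A rough standalone sketch of how such a --dbdriver value is split into
a driver name plus a URL query dict; it mirrors the change to
lib/sqlalchemy/testing/provision.py below, but uses the stdlib
urllib.parse directly and a made-up helper name:

    from urllib.parse import parse_qsl

    def split_dbdriver(spec):
        # "asyncpg?async_fallback=true" -> ("asyncpg", {"async_fallback": "true"})
        if "?" in spec:
            driver, query_str = spec.split("?", 1)
            return driver, dict(parse_qsl(query_str))
        return spec, {}

    print(split_dbdriver("asyncpg?async_fallback=true"))
    print(split_dbdriver("pg8000"))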

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #5451
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5451
Pull-request-sha: 639751ca9c7544801b9ede02e6cbe15a16c59c82

Change-Id: I2869bc52c330916773a41d11d12c297aecc8fcd8

13 files changed:
doc/build/changelog/unreleased_14/pg8000.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/provision.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/postgresql/test_dialect.py
test/engine/test_deprecations.py
test/engine/test_execute.py
test/engine/test_reconnect.py
test/requirements.py
tox.ini

diff --git a/doc/build/changelog/unreleased_14/pg8000.rst b/doc/build/changelog/unreleased_14/pg8000.rst
new file mode 100644
index 0000000..17c0a9d
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, postgresql
+
+    The pg8000 dialect has been revised and modernized for the most recent
+    version of the pg8000 driver for PostgreSQL.  Changes to the dialect
+    include:
+
+        * All data types are now sent as text rather than binary.
+
+        * Using adapters, custom types can be plugged in to pg8000.
+
+        * Previously, named prepared statements were used for all statements.
+          Now unnamed prepared statements are used by default, and named
+          prepared statements can be used explicitly by calling the
+          Connection.prepare() method, which returns a PreparedStatement
+          object.
+
+    Pull request courtesy Tony Locke.
diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py
index b86056da6cfea059e0fc88e39bed860cbe5cc6d4..bbe752d78273d497ac34ae985b18baec78b94148 100644
@@ -9,7 +9,7 @@ from ...testing.provision import temp_table_keyword_args
 
 
 @generate_driver_url.for_db("mysql", "mariadb")
-def generate_driver_url(url, driver):
+def generate_driver_url(url, driver, query):
     backend = url.get_backend_name()
 
     if backend == "mysql":
@@ -18,7 +18,10 @@ def generate_driver_url(url, driver):
             backend = "mariadb"
 
     new_url = copy.copy(url)
+    new_url.query = dict(new_url.query)
     new_url.drivername = "%s+%s" % (backend, driver)
+    new_url.query.update(query)
+
     try:
         new_url.get_dialect()
     except exc.NoSuchModuleError:
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 57c8f5a9af2591225be23c43162601645f83ad43..e08332a570a5da1afa5c48f458727137ba4f1f02 100644
@@ -9,13 +9,11 @@ r"""
     :name: pg8000
     :dbapi: pg8000
     :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
-    :url: https://pythonhosted.org/pg8000/
+    :url: https://pypi.org/project/pg8000/
 
-.. note::
-
-    The pg8000 dialect is **not tested as part of SQLAlchemy's continuous
-    integration** and may have unresolved issues.  The recommended PostgreSQL
-    dialect is psycopg2.
+.. versionchanged:: 1.4  The pg8000 dialect has been updated for version
+   1.16.5 and higher, and is again part of SQLAlchemy's continuous integration
+   with full feature support.
 
 .. _pg8000_unicode:
 
@@ -56,9 +54,6 @@ of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
 * ``SERIALIZABLE``
 * ``AUTOCOMMIT``
 
-.. versionadded:: 0.9.5 support for AUTOCOMMIT isolation level when using
-   pg8000.
-
 .. seealso::
 
     :ref:`postgresql_isolation_level`
@@ -74,12 +69,16 @@ from uuid import UUID as _python_UUID
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
 from .base import PGCompiler
 from .base import PGDialect
 from .base import PGExecutionContext
 from .base import PGIdentifierPreparer
 from .base import UUID
 from .json import JSON
+from .json import JSONB
+from .json import JSONPathType
 from ... import exc
 from ... import processors
 from ... import types as sqltypes
@@ -125,6 +124,40 @@ class _PGJSON(JSON):
         else:
             return super(_PGJSON, self).result_processor(dialect, coltype)
 
+    def get_dbapi_type(self, dbapi):
+        return dbapi.JSON
+
+
+class _PGJSONB(JSONB):
+    def result_processor(self, dialect, coltype):
+        if dialect._dbapi_version > (1, 10, 1):
+            return None  # Has native JSON
+        else:
+            return super(_PGJSONB, self).result_processor(dialect, coltype)
+
+    def get_dbapi_type(self, dbapi):
+        return dbapi.JSONB
+
+
+class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
+    def get_dbapi_type(self, dbapi):
+        raise NotImplementedError("should not be here")
+
+
+class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.STRING
+
+
+class _PGJSONPathType(JSONPathType):
+    def get_dbapi_type(self, dbapi):
+        return 1009
+
 
 class _PGUUID(UUID):
     def bind_processor(self, dialect):
@@ -148,8 +181,67 @@ class _PGUUID(UUID):
             return process
 
 
+class _PGEnum(ENUM):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.UNKNOWN
+
+
+class _PGInterval(INTERVAL):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTERVAL
+
+    @classmethod
+    def adapt_emulated_to_native(cls, interval, **kw):
+        return _PGInterval(precision=interval.second_precision)
+
+
+class _PGTimeStamp(sqltypes.DateTime):
+    def get_dbapi_type(self, dbapi):
+        if self.timezone:
+            # TIMESTAMPTZOID
+            return 1184
+        else:
+            # TIMESTAMPOID
+            return 1114
+
+
+class _PGTime(sqltypes.Time):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.TIME
+
+
+class _PGInteger(sqltypes.Integer):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class _PGSmallInteger(sqltypes.SmallInteger):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.INTEGER
+
+
+class _PGNullType(sqltypes.NullType):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.NULLTYPE
+
+
+class _PGBigInteger(sqltypes.BigInteger):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.BIGINTEGER
+
+
+class _PGBoolean(sqltypes.Boolean):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.BOOLEAN
+
+
 class PGExecutionContext_pg8000(PGExecutionContext):
-    pass
+    def pre_exec(self):
+        if not self.compiled:
+            return
+
+        if self.dialect._dbapi_version > (1, 16, 0):
+            self.set_input_sizes()
 
 
 class PGCompiler_pg8000(PGCompiler):
@@ -160,20 +252,11 @@ class PGCompiler_pg8000(PGCompiler):
             + self.process(binary.right, **kw)
         )
 
-    def post_process_text(self, text):
-        if "%%" in text:
-            util.warn(
-                "The SQLAlchemy postgresql dialect "
-                "now automatically escapes '%' in text() "
-                "expressions to '%%'."
-            )
-        return text.replace("%", "%%")
-
 
 class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
-    def _escape_identifier(self, value):
-        value = value.replace(self.escape_quote, self.escape_to_quote)
-        return value.replace("%", "%%")
+    def __init__(self, *args, **kwargs):
+        PGIdentifierPreparer.__init__(self, *args, **kwargs)
+        self._double_percents = False
 
 
 class PGDialect_pg8000(PGDialect):
@@ -195,9 +278,23 @@ class PGDialect_pg8000(PGDialect):
         {
             sqltypes.Numeric: _PGNumericNoBind,
             sqltypes.Float: _PGNumeric,
-            JSON: _PGJSON,
             sqltypes.JSON: _PGJSON,
+            sqltypes.Boolean: _PGBoolean,
+            sqltypes.NullType: _PGNullType,
+            JSONB: _PGJSONB,
+            sqltypes.JSON.JSONPathType: _PGJSONPathType,
+            sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
+            sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
+            sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
             UUID: _PGUUID,
+            sqltypes.Interval: _PGInterval,
+            INTERVAL: _PGInterval,
+            sqltypes.DateTime: _PGTimeStamp,
+            sqltypes.Time: _PGTime,
+            sqltypes.Integer: _PGInteger,
+            sqltypes.SmallInteger: _PGSmallInteger,
+            sqltypes.BigInteger: _PGBigInteger,
+            sqltypes.Enum: _PGEnum,
         },
     )
 
@@ -313,6 +410,17 @@ class PGDialect_pg8000(PGDialect):
 
             fns.append(on_connect)
 
+        if self._dbapi_version > (1, 16, 0) and self._json_deserializer:
+
+            def on_connect(conn):
+                # json
+                conn.register_in_adapter(114, self._json_deserializer)
+
+                # jsonb
+                conn.register_in_adapter(3802, self._json_deserializer)
+
+            fns.append(on_connect)
+
         if len(fns) > 0:
 
             def on_connect(conn):
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
index eb82a411eb2b0c99f3f76c70cac85a38a4b69555..6c6dc4be643f5c65aa59102f6c8c59cbd349ae2f 100644
@@ -1,30 +1,13 @@
-import copy
 import time
 
 from ... import exc
 from ... import text
 from ...testing.provision import create_db
 from ...testing.provision import drop_db
-from ...testing.provision import generate_driver_url
 from ...testing.provision import log
 from ...testing.provision import temp_table_keyword_args
 
 
-@generate_driver_url.for_db("postgresql")
-def generate_driver_url(url, driver):
-    new_url = copy.copy(url)
-    new_url.drivername = "postgresql+%s" % driver
-    if new_url.get_driver_name() == "asyncpg":
-        new_url.query = dict(new_url.query)
-        new_url.query["async_fallback"] = "true"
-    try:
-        new_url.get_dialect()
-    except exc.NoSuchModuleError:
-        return None
-    else:
-        return new_url
-
-
 @create_db.for_db("postgresql")
 def _pg_create_db(cfg, eng, ident):
     template_db = cfg.options.postgresql_templatedb
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index fe74be8235feee67a24be13ed74c12c22cc05461..f9fabbeed5df7ea366afb6c23d215b0532819639 100644
@@ -346,12 +346,12 @@ def _expect_raises(except_cls, msg=None, check_context=False):
     assert success, "Callable did not raise an exception"
 
 
-def expect_raises(except_cls):
-    return _expect_raises(except_cls, check_context=True)
+def expect_raises(except_cls, check_context=True):
+    return _expect_raises(except_cls, check_context=check_context)
 
 
-def expect_raises_message(except_cls, msg):
-    return _expect_raises(except_cls, msg=msg, check_context=True)
+def expect_raises_message(except_cls, msg, check_context=True):
+    return _expect_raises(except_cls, msg=msg, check_context=check_context)
 
 
 class AssertsCompiledSQL(object):
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 21bacfca2fc3381db599f80e315e447a355d21bb..094d1ea94bf3118c5b735880d0068e05baab70a4 100644
@@ -7,6 +7,7 @@ from . import engines
 from .. import exc
 from ..engine import url as sa_url
 from ..util import compat
+from ..util import parse_qsl
 
 log = logging.getLogger(__name__)
 
@@ -85,7 +86,7 @@ def generate_db_urls(db_urls, extra_drivers):
         --dburi postgresql://db1  \
         --dburi postgresql://db2  \
         --dburi postgresql://db2  \
-        --dbdriver=psycopg2 --dbdriver=asyncpg
+        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
 
     Noting that the default postgresql driver is psycopg2.  the output
     would be::
@@ -139,21 +140,34 @@ def _generate_driver_urls(url, extra_drivers):
     main_driver = url.get_driver_name()
     extra_drivers.discard(main_driver)
 
-    url = generate_driver_url(url, main_driver)
+    url = generate_driver_url(url, main_driver, {})
     yield str(url)
 
     for drv in list(extra_drivers):
-        new_url = generate_driver_url(url, drv)
+
+        if "?" in drv:
+
+            driver_only, query_str = drv.split("?", 1)
+
+            query = parse_qsl(query_str)
+        else:
+            driver_only = drv
+            query = {}
+
+        new_url = generate_driver_url(url, driver_only, query)
         if new_url:
             extra_drivers.remove(drv)
+
             yield str(new_url)
 
 
 @register.init
-def generate_driver_url(url, driver):
+def generate_driver_url(url, driver, query):
     backend = url.get_backend_name()
     new_url = copy.copy(url)
+    new_url.query = dict(new_url.query)
     new_url.drivername = "%s+%s" % (backend, driver)
+    new_url.query.update(query)
     try:
         new_url.get_dialect()
     except exc.NoSuchModuleError:
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 5e6ac1eabd913b216326b10ad5649aa08d805aa9..6a390231bbac8c3e31c715fff5efc1cd23834658 100644
@@ -34,7 +34,6 @@ from ... import testing
 from ... import Text
 from ... import Time
 from ... import TIMESTAMP
-from ... import type_coerce
 from ... import TypeDecorator
 from ... import Unicode
 from ... import UnicodeText
@@ -1161,37 +1160,6 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
         )
 
-    def test_crit_against_int_basic(self):
-        name = self.tables.data_table.c.name
-        col = self.tables.data_table.c["data"]
-
-        self._test_index_criteria(
-            and_(name == "r6", cast(col["a"], String) == "5"), "r6"
-        )
-
-    def _dont_test_crit_against_string_coerce_type(self):
-        name = self.tables.data_table.c.name
-        col = self.tables.data_table.c["data"]
-
-        self._test_index_criteria(
-            and_(
-                name == "r6",
-                cast(col["b"], String) == type_coerce("some value", JSON),
-            ),
-            "r6",
-            test_literal=False,
-        )
-
-    def _dont_test_crit_against_int_coerce_type(self):
-        name = self.tables.data_table.c.name
-        col = self.tables.data_table.c["data"]
-
-        self._test_index_criteria(
-            and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)),
-            "r6",
-            test_literal=False,
-        )
-
 
 __all__ = (
     "UnicodeVarcharTest",
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index d15e3a843c10345d95d2f0dfe7ed5fbb302f9170..6eaa3295b9aa09882ea03f3d86083c6ce665742b 100644
@@ -982,7 +982,7 @@ $$ LANGUAGE plpgsql;
             t = Table("t", m, Column("c", type_, primary_key=True))
 
             if version:
-                dialect = postgresql.dialect()
+                dialect = testing.db.dialect.__class__()
                 dialect._get_server_version_info = mock.Mock(
                     return_value=version
                 )
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index 0d02ab3e7bb60e284ad77cf84242672e8f78c9f8..f09f0f1e127e494c4d4b91cad77ceda7292b94d4 100644
@@ -579,10 +579,6 @@ class RawExecuteTest(fixtures.TablesTest):
             Column("user_name", VARCHAR(20)),
         )
 
-    @testing.fails_on(
-        "postgresql+pg8000",
-        "pg8000 still doesn't allow single paren without params",
-    )
     def test_no_params_option(self, connection):
         stmt = (
             "SELECT '%'"
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 89d5c634866292b5f2ee98b99a75c1047ef0066e..fd42224ebb5c7db829edb04ee28ea5576581b5df 100644
@@ -77,10 +77,6 @@ class ExecuteTest(fixtures.TablesTest):
             Column("user_name", VARCHAR(20)),
         )
 
-    @testing.fails_on(
-        "postgresql+pg8000",
-        "pg8000 still doesn't allow single paren without params",
-    )
     def test_no_params_option(self):
         stmt = (
             "SELECT '%'"
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 48eb485cb7a76f8e2d632af137284c953c049f22..53a5ec6f4de44d3cfcfd2bcb54cb552ddb5e7382 100644
@@ -17,6 +17,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assert_raises_message_context_ok
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -1312,12 +1313,10 @@ class PrePingRealTest(fixtures.TestBase):
         eq_(conn.execute(select(1)).scalar(), 1)
         conn.close()
 
-        def exercise_stale_connection():
+        with expect_raises(engine.dialect.dbapi.Error, check_context=False):
             curs = stale_connection.cursor()
             curs.execute("select 1")
 
-        assert_raises(engine.dialect.dbapi.Error, exercise_stale_connection)
-
     def test_pre_ping_db_stays_shutdown(self):
         engine = engines.reconnecting_engine(options={"pool_pre_ping": True})
 
diff --git a/test/requirements.py b/test/requirements.py
index 99a3605658d8f05125f67e389c90ba7e7eabf878..99a6f5a3b4e1e1815972f5b432a3c7c298f7507e 100644
@@ -1017,7 +1017,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def json_array_indexes(self):
-        return self.json_type + fails_if("+pg8000")
+        return self.json_type
 
     @property
     def datetime_literals(self):
@@ -1209,20 +1209,6 @@ class DefaultRequirements(SuiteRequirements):
                     "Firebird still has FP inaccuracy even "
                     "with only four decimal places",
                 ),
-                (
-                    "postgresql+pg8000",
-                    None,
-                    None,
-                    "postgresql+pg8000 has FP inaccuracy even with "
-                    "only four decimal places ",
-                ),
-                (
-                    "postgresql+psycopg2cffi",
-                    None,
-                    None,
-                    "postgresql+psycopg2cffi has FP inaccuracy even with "
-                    "only four decimal places ",
-                ),
             ]
         )
 
@@ -1253,7 +1239,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def duplicate_key_raises_integrity_error(self):
-        return fails_on("postgresql+pg8000")
+        return exclusions.open()
 
     def _has_pg_extension(self, name):
         def check(config):
diff --git a/tox.ini b/tox.ini
index 92fe031724d0a2f77bd71e01797ce1235f1a75b3..ac95dc42c67ccb95c6b898658a13073b1b3d268f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -21,7 +21,8 @@ deps=pytest!=3.9.1,!=3.9.2
      mock; python_version < '3.3'
      importlib_metadata; python_version < '3.8'
      postgresql: .[postgresql]
-     postgresql: .[postgresql_asyncpg]
+     postgresql: .[postgresql_asyncpg]; python_version >= '3'
+     postgresql: .[postgresql_pg8000]; python_version >= '3'
      mysql: .[mysql]
      mysql: .[pymysql]
      oracle: .[oracle]
@@ -66,11 +67,12 @@ setenv=
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
 
     postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
-    postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg}
+    py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg?async_fallback=true --dbdriver pg8000}
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
     mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
 
+
     mssql: MSSQL={env:TOX_MSSQL:--db mssql}
 
     oracle,mssql,sqlite_file: IDENTS=--write-idents db_idents.txt