]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve array support on pg8000
authorFederico Caselli <cfederico87@gmail.com>
Tue, 16 Mar 2021 23:27:18 +0000 (00:27 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 27 Oct 2021 20:10:52 +0000 (22:10 +0200)
References: #6023

Change-Id: I0f6cbc34b3c0bfc0b8c86b3ebe4531e23039b6c0

doc/build/changelog/unreleased_14/6023.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/sql/type_api.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/6023.rst b/doc/build/changelog/unreleased_14/6023.rst
new file mode 100644 (file)
index 0000000..88d9777
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: postgresql, pg8000
+    :tickets: 7167
+
+    Improve array handling when using PostgreSQL with the
+    pg8000 dialect.
index 9659d31b93aafc367de9aef82b66f0c5f22cbc79..0cb574dacf7b5f52f34f652c38ff132378ed011b 100644 (file)
@@ -375,7 +375,7 @@ class ARRAY(sqltypes.ARRAY):
                 if value is None:
                     return value
                 # isinstance(value, util.string_types) is required to handle
-                # the case where a TypeDecorator for and Array of Enum is
+                # the case where a TypeDecorator for and Array of Enum is
                 # used like was required in sa < 1.3.17
                 return super_rp(
                     handle_raw_string(value)
index d42dd9560d94b846d6d7ae8ee73148b1061d1f41..a94f9dcdbb0057024428ef49c78a75df230491df 100644 (file)
@@ -93,6 +93,8 @@ import decimal
 import re
 from uuid import UUID as _python_UUID
 
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -256,6 +258,11 @@ class _PGBoolean(sqltypes.Boolean):
         return dbapi.BOOLEAN
 
 
+class _PGARRAY(PGARRAY):
+    def bind_expression(self, bindvalue):
+        return _ColonCast(bindvalue, self)
+
+
 _server_side_id = util.counter()
 
 
@@ -384,6 +391,7 @@ class PGDialect_pg8000(PGDialect):
             sqltypes.SmallInteger: _PGSmallInteger,
             sqltypes.BigInteger: _PGBigInteger,
             sqltypes.Enum: _PGEnum,
+            sqltypes.ARRAY: _PGARRAY,
         },
     )
 
index acf88f0daf846eb2100d0f906fd0d4b146091e69..2a4688bcceb3c9e75c30966a12176e877cc39fcf 100644 (file)
@@ -633,7 +633,8 @@ class TypeEngine(Traversible):
         try:
             return dialect._type_memos[self]["impl"]
         except KeyError:
-            return self._dialect_info(dialect)["impl"]
+            pass
+        return self._dialect_info(dialect)["impl"]
 
     def _unwrapped_dialect_impl(self, dialect):
         """Return the 'unwrapped' dialect impl for this type.
index dd0a1be0f306e3963888a4f0a905a31a207bb23f..d1c0361e4f9a339f4cc3d7cf207c31693f8435e2 100644 (file)
@@ -1443,7 +1443,6 @@ class ArrayRoundTripTest(object):
 
     __only_on__ = "postgresql"
     __backend__ = True
-    __unsupported_on__ = ("postgresql+pg8000",)
 
     ARRAY = postgresql.ARRAY
 
@@ -1962,14 +1961,8 @@ class ArrayRoundTripTest(object):
             (sqltypes.Unicode, unicode_values),
             (postgresql.JSONB, json_values),
             (sqltypes.Boolean, lambda x: [False] + [True] * x),
-            (
-                sqltypes.LargeBinary,
-                binary_values,
-            ),
-            (
-                postgresql.BYTEA,
-                binary_values,
-            ),
+            (sqltypes.LargeBinary, binary_values),
+            (postgresql.BYTEA, binary_values),
             (
                 postgresql.INET,
                 lambda x: [
@@ -2047,6 +2040,7 @@ class ArrayRoundTripTest(object):
             (postgresql.ENUM(AnEnum), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=False), enum_values),
+            (postgresql.ENUM(AnEnum, native_enum=True), enum_values),
         ]
 
         if not exclude_json:
@@ -2057,6 +2051,22 @@ class ArrayRoundTripTest(object):
                 ]
             )
 
+        _pg8000_skip_types = {
+            postgresql.HSTORE,  # return not parsed returned as string
+        }
+        for i in range(len(elements)):
+            elem = elements[i]
+            if (
+                elem[0] in _pg8000_skip_types
+                or type(elem[0]) in _pg8000_skip_types
+            ):
+                elem += (
+                    testing.skip_if(
+                        "postgresql+pg8000", "type not supported by pg8000"
+                    ),
+                )
+                elements[i] = elem
+
         return testing.combinations_list(
             elements, argnames="type_,gen", id_="na"
         )