]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve Asyncpg json handling
authorFederico Caselli <cfederico87@gmail.com>
Mon, 21 Sep 2020 17:59:00 +0000 (19:59 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 23 Sep 2020 19:29:56 +0000 (21:29 +0200)
Set default type codec for ``json`` and ``jsonb`` types when using
the asyncpg driver. By default asyncpg will not decode them and return
strings instead.

Fixes: #5584
Change-Id: I41348eff8096ccf87b952d7e797c0694c6c4b5c4

lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_types.py

index 6fa1dd78beee228c8ade143015083e13639ec3c7..1f988153c5ea548b9f2b182b4a1225da76784c4b 100644 (file)
@@ -36,11 +36,21 @@ in conjunction with :func:`_sa.craete_engine`::
 
 .. versionadded:: 1.4
 
+.. note::
+
+    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
+    returns them as strings. SQLAlchemy sets default type decoder for ``json``
+    and ``jsonb`` types using the python builtin ``json.loads`` function.
+    The json implementation used can be changed by setting the attribute
+    ``json_deserializer`` when creating the engine with
+    :func:`create_engine` or :func:`create_async_engine`.
+
 """  # noqa
 
 import collections
 import decimal
 import itertools
+import json as _py_json
 import re
 
 from . import json
@@ -123,11 +133,17 @@ class AsyncpgJSON(json.JSON):
     def get_dbapi_type(self, dbapi):
         return dbapi.JSON
 
+    def result_processor(self, dialect, coltype):
+        return None
+
 
 class AsyncpgJSONB(json.JSONB):
     def get_dbapi_type(self, dbapi):
         return dbapi.JSONB
 
+    def result_processor(self, dialect, coltype):
+        return None
+
 
 class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
     def get_dbapi_type(self, dbapi):
@@ -481,17 +497,6 @@ class AsyncAdapt_asyncpg_connection:
         self.deferrable = False
         self._transaction = None
         self._started = False
-        self.await_(self._setup_type_codecs())
-
-    async def _setup_type_codecs(self):
-        """set up type decoders at the asyncpg level.
-
-        these are set_type_codec() calls to normalize
-        There was a tentative decoder for the "char" datatype here
-        to have it return strings however this type is actually a binary
-        type that other drivers are likely mis-interpreting.
-
-        """
 
     def _handle_exception(self, error):
         if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
@@ -781,5 +786,56 @@ class PGDialect_asyncpg(PGDialect):
                 e, self.dbapi.InterfaceError
             ) and "connection is closed" in str(e)
 
+    def on_connect(self):
+        super_connect = super(PGDialect_asyncpg, self).on_connect()
+
+        def _jsonb_encoder(str_value):
+            # \x01 is the prefix for jsonb used by PostgreSQL.
+            # asyncpg requires it when format='binary'
+            return b"\x01" + str_value.encode()
+
+        deserializer = self._json_deserializer or _py_json.loads
+
+        def _json_decoder(bin_value):
+            return deserializer(bin_value.decode())
+
+        def _jsonb_decoder(bin_value):
+            # the byte is the \x01 prefix for jsonb used by PostgreSQL.
+            # asyncpg returns it when format='binary'
+            return deserializer(bin_value[1:].decode())
+
+        async def _setup_type_codecs(conn):
+            """set up type decoders at the asyncpg level.
+
+            these are set_type_codec() calls to normalize
+            There was a tentative decoder for the "char" datatype here
+            to have it return strings however this type is actually a binary
+            type that other drivers are likely mis-interpreting.
+
+            See https://github.com/MagicStack/asyncpg/issues/623 for reference
+            on why it's set up this way.
+            """
+            await conn._connection.set_type_codec(
+                "json",
+                encoder=str.encode,
+                decoder=_json_decoder,
+                schema="pg_catalog",
+                format="binary",
+            )
+            await conn._connection.set_type_codec(
+                "jsonb",
+                encoder=_jsonb_encoder,
+                decoder=_jsonb_decoder,
+                schema="pg_catalog",
+                format="binary",
+            )
+
+        def connect(conn):
+            conn.await_(_setup_type_codecs(conn))
+            if super_connect is not None:
+                super_connect(conn)
+
+        return connect
+
 
 dialect = PGDialect_asyncpg
index f13f251326140cdd82c6038634c75dd52eb99567..5def5aa5b77234c3ffd33485b1c7731be7ee9294 100644 (file)
@@ -1926,16 +1926,69 @@ class ArrayJSON(fixtures.TestBase):
         connection.execute(
             tbl.insert(),
             [
-                {"json_col": ["foo"]},
-                {"json_col": [{"foo": "bar"}, [1]]},
-                {"json_col": [None]},
+                {"id": 1, "json_col": ["foo"]},
+                {"id": 2, "json_col": [{"foo": "bar"}, [1]]},
+                {"id": 3, "json_col": [None]},
+                {"id": 4, "json_col": [42]},
+                {"id": 5, "json_col": [True]},
+                {"id": 6, "json_col": None},
             ],
         )
 
         sel = select(tbl.c.json_col).order_by(tbl.c.id)
         eq_(
             connection.execute(sel).fetchall(),
-            [(["foo"],), ([{"foo": "bar"}, [1]],), ([None],)],
+            [
+                (["foo"],),
+                ([{"foo": "bar"}, [1]],),
+                ([None],),
+                ([42],),
+                ([True],),
+                (None,),
+            ],
+        )
+
+        eq_(
+            connection.exec_driver_sql(
+                """select json_col::text = array['"foo"']::json[]::text"""
+                " from json_table where id = 1"
+            ).scalar(),
+            True,
+        )
+        eq_(
+            connection.exec_driver_sql(
+                "select json_col::text = "
+                """array['{"foo": "bar"}', '[1]']::json[]::text"""
+                " from json_table where id = 2"
+            ).scalar(),
+            True,
+        )
+        eq_(
+            connection.exec_driver_sql(
+                """select json_col::text = array['null']::json[]::text"""
+                " from json_table where id = 3"
+            ).scalar(),
+            True,
+        )
+        eq_(
+            connection.exec_driver_sql(
+                """select json_col::text = array['42']::json[]::text"""
+                " from json_table where id = 4"
+            ).scalar(),
+            True,
+        )
+        eq_(
+            connection.exec_driver_sql(
+                """select json_col::text = array['true']::json[]::text"""
+                " from json_table where id = 5"
+            ).scalar(),
+            True,
+        )
+        eq_(
+            connection.exec_driver_sql(
+                "select json_col is null from json_table where id = 6"
+            ).scalar(),
+            True,
         )
 
 
@@ -3127,16 +3180,18 @@ class JSONRoundTripTest(fixtures.TablesTest):
 
     def _fixture_data(self, engine):
         data_table = self.tables.data_table
+
+        data = [
+            {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+            {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+            {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+            {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+            {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
+            {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
+        ]
         with engine.begin() as conn:
-            conn.execute(
-                data_table.insert(),
-                {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
-                {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
-                {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
-                {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
-                {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
-                {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
-            )
+            conn.execute(data_table.insert(), data)
+        return data
 
     def _assert_data(self, compare, conn, column="data"):
         col = self.tables.data_table.c[column]
@@ -3357,6 +3412,17 @@ class JSONRoundTripTest(fixtures.TablesTest):
             ("null", None),
         )
 
+    def test_literal(self, connection):
+        exp = self._fixture_data(testing.db)
+        result = connection.exec_driver_sql(
+            "select data from data_table order by name"
+        )
+        res = list(result)
+        eq_(len(res), len(exp))
+        for row, expected in zip(res, exp):
+            eq_(row[0], expected["data"])
+        result.close()
+
 
 class JSONBTest(JSONTest):
     def setup(self):