From: Federico Caselli Date: Mon, 21 Sep 2020 17:59:00 +0000 (+0200) Subject: Improve Asyncpg json handling X-Git-Tag: rel_1_4_0b1~78^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=219c717e2357439e719464add9f86dc2f40ae667;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve Asyncpg json handling Set default type codec for ``json`` and ``jsonb`` types when using the asyncpg driver. By default asyncpg does not decode them and returns them as strings instead. Fixes: #5584 Change-Id: I41348eff8096ccf87b952d7e797c0694c6c4b5c4 --- diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 6fa1dd78be..1f988153c5 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -36,11 +36,21 @@ in conjunction with :func:`_sa.craete_engine`:: .. versionadded:: 1.4 +.. note:: + + By default asyncpg does not decode the ``json`` and ``jsonb`` types and + returns them as strings. SQLAlchemy sets default type decoder for ``json`` + and ``jsonb`` types using the python builtin ``json.loads`` function. + The json implementation used can be changed by setting the attribute + ``json_deserializer`` when creating the engine with + :func:`create_engine` or :func:`create_async_engine`. + """ # noqa import collections import decimal import itertools +import json as _py_json import re from . 
import json @@ -123,11 +133,17 @@ class AsyncpgJSON(json.JSON): def get_dbapi_type(self, dbapi): return dbapi.JSON + def result_processor(self, dialect, coltype): + return None + class AsyncpgJSONB(json.JSONB): def get_dbapi_type(self, dbapi): return dbapi.JSONB + def result_processor(self, dialect, coltype): + return None + class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): def get_dbapi_type(self, dbapi): @@ -481,17 +497,6 @@ class AsyncAdapt_asyncpg_connection: self.deferrable = False self._transaction = None self._started = False - self.await_(self._setup_type_codecs()) - - async def _setup_type_codecs(self): - """set up type decoders at the asyncpg level. - - these are set_type_codec() calls to normalize - There was a tentative decoder for the "char" datatype here - to have it return strings however this type is actually a binary - type that other drivers are likely mis-interpreting. - - """ def _handle_exception(self, error): if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): @@ -781,5 +786,56 @@ class PGDialect_asyncpg(PGDialect): e, self.dbapi.InterfaceError ) and "connection is closed" in str(e) + def on_connect(self): + super_connect = super(PGDialect_asyncpg, self).on_connect() + + def _jsonb_encoder(str_value): + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + str_value.encode() + + deserializer = self._json_deserializer or _py_json.loads + + def _json_decoder(bin_value): + return deserializer(bin_value.decode()) + + def _jsonb_decoder(bin_value): + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return deserializer(bin_value[1:].decode()) + + async def _setup_type_codecs(conn): + """set up type decoders at the asyncpg level. 
+ + these are set_type_codec() calls to normalize + There was a tentative decoder for the "char" datatype here + to have it return strings however this type is actually a binary + type that other drivers are likely mis-interpreting. + + See https://github.com/MagicStack/asyncpg/issues/623 for reference + on why it's set up this way. + """ + await conn._connection.set_type_codec( + "json", + encoder=str.encode, + decoder=_json_decoder, + schema="pg_catalog", + format="binary", + ) + await conn._connection.set_type_codec( + "jsonb", + encoder=_jsonb_encoder, + decoder=_jsonb_decoder, + schema="pg_catalog", + format="binary", + ) + + def connect(conn): + conn.await_(_setup_type_codecs(conn)) + if super_connect is not None: + super_connect(conn) + + return connect + dialect = PGDialect_asyncpg diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index f13f251326..5def5aa5b7 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1926,16 +1926,69 @@ class ArrayJSON(fixtures.TestBase): connection.execute( tbl.insert(), [ - {"json_col": ["foo"]}, - {"json_col": [{"foo": "bar"}, [1]]}, - {"json_col": [None]}, + {"id": 1, "json_col": ["foo"]}, + {"id": 2, "json_col": [{"foo": "bar"}, [1]]}, + {"id": 3, "json_col": [None]}, + {"id": 4, "json_col": [42]}, + {"id": 5, "json_col": [True]}, + {"id": 6, "json_col": None}, ], ) sel = select(tbl.c.json_col).order_by(tbl.c.id) eq_( connection.execute(sel).fetchall(), - [(["foo"],), ([{"foo": "bar"}, [1]],), ([None],)], + [ + (["foo"],), + ([{"foo": "bar"}, [1]],), + ([None],), + ([42],), + ([True],), + (None,), + ], + ) + + eq_( + connection.exec_driver_sql( + """select json_col::text = array['"foo"']::json[]::text""" + " from json_table where id = 1" + ).scalar(), + True, + ) + eq_( + connection.exec_driver_sql( + "select json_col::text = " + """array['{"foo": "bar"}', '[1]']::json[]::text""" + " from json_table where id = 2" + ).scalar(), + True, + ) 
+ eq_( + connection.exec_driver_sql( + """select json_col::text = array['null']::json[]::text""" + " from json_table where id = 3" + ).scalar(), + True, + ) + eq_( + connection.exec_driver_sql( + """select json_col::text = array['42']::json[]::text""" + " from json_table where id = 4" + ).scalar(), + True, + ) + eq_( + connection.exec_driver_sql( + """select json_col::text = array['true']::json[]::text""" + " from json_table where id = 5" + ).scalar(), + True, + ) + eq_( + connection.exec_driver_sql( + "select json_col is null from json_table where id = 6" + ).scalar(), + True, ) @@ -3127,16 +3180,18 @@ class JSONRoundTripTest(fixtures.TablesTest): def _fixture_data(self, engine): data_table = self.tables.data_table + + data = [ + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, + {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, + ] with engine.begin() as conn: - conn.execute( - data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, - {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, - {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, - {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, - {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, - ) + conn.execute(data_table.insert(), data) + return data def _assert_data(self, compare, conn, column="data"): col = self.tables.data_table.c[column] @@ -3357,6 +3412,17 @@ class JSONRoundTripTest(fixtures.TablesTest): ("null", None), ) + def test_literal(self, connection): + exp = self._fixture_data(testing.db) + result = connection.exec_driver_sql( + "select data from data_table order by name" + ) + res = list(result) + eq_(len(res), len(exp)) + for row, expected in 
zip(res, exp): + eq_(row[0], expected["data"]) + result.close() + class JSONBTest(JSONTest): def setup(self):