From: Federico Caselli Date: Tue, 2 Aug 2022 10:27:57 +0000 (+0200) Subject: JSONPATH type can be used in casts in PostgreSQL X-Git-Tag: rel_2_0_0b1~113^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2e7117ab1b584e678380c70625ad1331cea551d0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git JSONPATH type can be used in casts in PostgreSQL Introduced the type :class:`_postgresql.JSONPATH` that can be used in cast expressions. This is required by some PostgreSQL dialects when using functions such as ``jsonb_path_exists`` or ``jsonb_path_match`` that accept a ``jsonpath`` as input. Fixes: #8216 Change-Id: I3e7337eab91680cab1604e1f3058854a0a19c5be --- diff --git a/doc/build/changelog/unreleased_20/8216.rst b/doc/build/changelog/unreleased_20/8216.rst new file mode 100644 index 0000000000..c213e37cd5 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8216.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: postgresql, schema + :tickets: 8216 + + Introduced the type :class:`_postgresql.JSONPATH` that can be used + in cast expressions. This is required by some PostgreSQL dialects + when using functions such as ``jsonb_path_exists`` or + ``jsonb_path_match`` that accept a ``jsonpath`` as input. + + .. seealso:: + + :ref:`postgresql_json_types` - PostgreSQL JSON types. \ No newline at end of file diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 9ec6ee9612..f6b0073802 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -20,6 +20,8 @@ as well as array literals: * :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate function syntax. +.. _postgresql_json_types: + JSON Types ---------- @@ -31,6 +33,8 @@ operators: * :class:`_postgresql.JSONB` +* :class:`_postgresql.JSONPATH` + HSTORE Type ----------- @@ -362,6 +366,8 @@ construction arguments, are as follows: .. autoclass:: JSONB :members: +.. autoclass:: JSONPATH + .. autoclass:: MACADDR .. autoclass:: MONEY diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 104077a171..8dbee1f7f4 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -41,6 +41,7 @@ from .hstore import HSTORE from .hstore import hstore from .json import JSON from .json import JSONB +from .json import JSONPATH from .named_types import CreateDomainType from .named_types import CreateEnumType from .named_types import DropDomainType @@ -128,6 +129,7 @@ __all__ = ( "TSTZMULTIRANGE", "JSON", "JSONB", + "JSONPATH", "Any", "All", "DropEnumType", diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 38f8fddee6..6888959f0a 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -122,7 +122,6 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: from __future__ import annotations import collections -import collections.abc as collections_abc import decimal import json as _py_json import re @@ -231,9 +230,15 @@ class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): class AsyncpgJSONPathType(json.JSONPathType): def bind_processor(self, dialect): def process(value): - assert isinstance(value, collections_abc.Sequence) - tokens = [str(elem) for elem in value] - return tokens + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + return value + elif value: + tokens = [str(elem) for elem in value] + return tokens + else: + return [] return process diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 35a09cfa76..75fa3d7c77 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1512,7 +1512,7 @@ colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM, - sqltypes.JSON.JSONPathType: _json.JSONPathType, + sqltypes.JSON.JSONPathType: _json.JSONPATH, sqltypes.JSON: _json.JSON, UUID: PGUuid, } @@ -2503,6 +2503,12 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): count=1, ) + def visit_json_path(self, type_, **kw): + return self.visit_JSONPATH(type_, **kw) + + def visit_JSONPATH(self, type_, **kw): + return "JSONPATH" + class PGIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 8763a0ca20..a8b03bd482 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -6,7 +6,6 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors -import collections.abc as collections_abc from ... import types as sqltypes from ...sql import operators @@ -68,31 +67,47 @@ CONTAINED_BY = operators.custom_op( class JSONPathType(sqltypes.JSON.JSONPathType): - def bind_processor(self, dialect): - super_proc = self.string_bind_processor(dialect) - + def _processor(self, dialect, super_proc): def process(value): - assert isinstance(value, collections_abc.Sequence) - tokens = [str(elem) for elem in value] - value = "{%s}" % (", ".join(tokens)) + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + return value + elif value: + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + value = "{%s}" % (", ".join(map(str, value))) + else: + value = "{}" if super_proc: value = super_proc(value) return value return process + def bind_processor(self, dialect): + return self._processor(dialect, self.string_bind_processor(dialect)) + def literal_processor(self, dialect): - super_proc = self.string_literal_processor(dialect) + return self._processor(dialect, self.string_literal_processor(dialect)) - def process(value): - assert isinstance(value, collections_abc.Sequence) - tokens = [str(elem) for elem in value] - value = "{%s}" % (", ".join(tokens)) - if super_proc: - value = super_proc(value) - return value - return process +class JSONPATH(JSONPathType): + """JSON Path Type. + + This is usually required to cast literal values to json path when using + json search like function, such as ``jsonb_path_query_array`` or + ``jsonb_path_exists``:: + + stmt = sa.select( + sa.func.jsonb_path_query_array( + table.c.jsonb_col, cast("$.address.id", JSONPATH) + ) + ) + + """ + + __visit_name__ = "JSONPATH" class JSON(sqltypes.JSON): diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index de833cd893..f04d583e11 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2419,6 +2419,8 @@ class JSON(Indexable, TypeEngine[Any]): """ + __visit_name__ = "json_path" + class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 5e5c4f9bdb..67e54e4f51 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -41,6 +41,8 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONPATH from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 @@ -2354,6 +2356,18 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE data.group_id = summary.group_id)", ) + @testing.combinations(JSONB.JSONPathType, JSONPATH) + def test_json_path(self, type_): + data = table("data", column("id", Integer), column("x", JSONB)) + stmt = select( + func.jsonb_path_exists(data.c.x, cast("$.data.w", type_)) + ) + self.assert_compile( + stmt, + "SELECT jsonb_path_exists(data.x, CAST(%(param_1)s AS JSONPATH)) " + "AS jsonb_path_exists_1 FROM data", + ) + class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = postgresql.dialect() diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index b4c19238d3..8f655362e8 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4496,8 +4496,11 @@ class JSONRoundTripTest(fixtures.TablesTest): Column("nulldata", cls.data_type(none_as_null=True)), ) + @property + def data_table(self): + return self.tables.data_table + def _fixture_data(self, connection): - data_table = self.tables.data_table data = [ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, @@ -4507,23 +4510,23 @@ class JSONRoundTripTest(fixtures.TablesTest): {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, ] - connection.execute(data_table.insert(), data) + connection.execute(self.data_table.insert(), data) return data def _assert_data(self, compare, conn, column="data"): - col = self.tables.data_table.c[column] + col = self.data_table.c[column] data = conn.execute( - select(col).order_by(self.tables.data_table.c.name) + select(col).order_by(self.data_table.c.name) ).fetchall() eq_([d for d, in data], compare) def _assert_column_is_NULL(self, conn, column="data"): - col = self.tables.data_table.c[column] + col = self.data_table.c[column] data = conn.execute(select(col).where(col.is_(null()))).fetchall() eq_([d for d, in data], [None]) def _assert_column_is_JSON_NULL(self, conn, column="data"): - col = self.tables.data_table.c[column] + col = self.data_table.c[column] data = conn.execute( select(col).where(cast(col, String) == "null") ).fetchall() @@ -4539,7 +4542,7 @@ class JSONRoundTripTest(fixtures.TablesTest): argnames="key", ) def test_indexed_special_keys(self, connection, key): - data_table = self.tables.data_table + data_table = self.data_table data_element = {key: "some value"} connection.execute( @@ -4563,27 +4566,27 @@ class JSONRoundTripTest(fixtures.TablesTest): def test_insert(self, connection): connection.execute( - self.tables.data_table.insert(), + self.data_table.insert(), {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, ) self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection) def test_insert_nulls(self, connection): connection.execute( - self.tables.data_table.insert(), {"name": "r1", "data": null()} + self.data_table.insert(), {"name": "r1", "data": null()} ) self._assert_data([None], connection) def test_insert_none_as_null(self, connection): connection.execute( - self.tables.data_table.insert(), + self.data_table.insert(), {"name": "r1", "nulldata": None}, ) self._assert_column_is_NULL(connection, column="nulldata") def test_insert_nulljson_into_none_as_null(self, connection): connection.execute( - self.tables.data_table.insert(), + self.data_table.insert(), {"name": "r1", "nulldata": JSON.NULL}, ) self._assert_column_is_JSON_NULL(connection, column="nulldata") @@ -4654,9 +4657,8 @@ class JSONRoundTripTest(fixtures.TablesTest): def test_query_returned_as_text(self, connection): self._fixture_data(connection) - data_table = self.tables.data_table result = connection.execute( - select(data_table.c.data["k1"].astext) + select(self.data_table.c.data["k1"].astext) ).first() assert isinstance(result[0], str) @@ -4795,6 +4797,21 @@ class JSONBRoundTripTest(JSONRoundTripTest): def test_unicode_round_trip(self, connection): super(JSONBRoundTripTest, self).test_unicode_round_trip(connection) + @testing.only_on("postgresql >= 12") + def test_cast_jsonpath(self, connection): + self._fixture_data(connection) + + def go(path, res): + q = select(func.count("*")).where( + func.jsonb_path_exists( + self.data_table.c.data, cast(path, JSONB.JSONPathType) + ) + ) + eq_(connection.scalar(q), res) + + go("$.k1.k2", 0) + go("$.k1.r6v1", 1) + class JSONBSuiteTest(suite.JSONTest): __requires__ = ("postgresql_jsonb",)