]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
JSONPATH type can be used in casts in PostgreSQL
authorFederico Caselli <cfederico87@gmail.com>
Tue, 2 Aug 2022 10:27:57 +0000 (12:27 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Aug 2022 15:50:24 +0000 (11:50 -0400)
Introduced the type :class:`_postgresql.JSONPATH` that can be used
in cast expressions. This is required by some PostgreSQL dialects
when using functions such as ``jsonb_path_exists`` or
``jsonb_path_match`` that accept a ``jsonpath`` as input.

Fixes: #8216
Change-Id: I3e7337eab91680cab1604e1f3058854a0a19c5be

doc/build/changelog/unreleased_20/8216.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/8216.rst b/doc/build/changelog/unreleased_20/8216.rst
new file mode 100644 (file)
index 0000000..c213e37
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: postgresql, schema
+    :tickets: 8216
+
+    Introduced the type :class:`_postgresql.JSONPATH` that can be used
+    in cast expressions. This is required by some PostgreSQL dialects
+    when using functions such as ``jsonb_path_exists`` or
+    ``jsonb_path_match`` that accept a ``jsonpath`` as input.
+
+    .. seealso::
+
+        :ref:`postgresql_json_types` - PostgreSQL JSON types.
\ No newline at end of file
index 9ec6ee9612fe17378c4357373aca8ce1a2225991..f6b0073802a7601bfca34c9e7b4e9ec71f5a64bb 100644 (file)
@@ -20,6 +20,8 @@ as well as array literals:
 * :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate
   function syntax.
 
+.. _postgresql_json_types:
+
 JSON Types
 ----------
 
@@ -31,6 +33,8 @@ operators:
 
 * :class:`_postgresql.JSONB`
 
+* :class:`_postgresql.JSONPATH`
+
 HSTORE Type
 -----------
 
@@ -362,6 +366,8 @@ construction arguments, are as follows:
 .. autoclass:: JSONB
     :members:
 
+.. autoclass:: JSONPATH
+
 .. autoclass:: MACADDR
 
 .. autoclass:: MONEY
index 104077a171c70f6d4a48fdd0d8bedc851a376ca7..8dbee1f7f473d39e78ebfe41e26de494de50d226 100644 (file)
@@ -41,6 +41,7 @@ from .hstore import HSTORE
 from .hstore import hstore
 from .json import JSON
 from .json import JSONB
+from .json import JSONPATH
 from .named_types import CreateDomainType
 from .named_types import CreateEnumType
 from .named_types import DropDomainType
@@ -128,6 +129,7 @@ __all__ = (
     "TSTZMULTIRANGE",
     "JSON",
     "JSONB",
+    "JSONPATH",
     "Any",
     "All",
     "DropEnumType",
index 38f8fddee66dba62e60ba58117d65b6c378f09a4..6888959f0a27e1ae4fc1372a29a533cfa2e3755d 100644 (file)
@@ -122,7 +122,6 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
 from __future__ import annotations
 
 import collections
-import collections.abc as collections_abc
 import decimal
 import json as _py_json
 import re
@@ -231,9 +230,15 @@ class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
 class AsyncpgJSONPathType(json.JSONPathType):
     def bind_processor(self, dialect):
         def process(value):
-            assert isinstance(value, collections_abc.Sequence)
-            tokens = [str(elem) for elem in value]
-            return tokens
+            if isinstance(value, str):
+                # If it's already a string assume that it's in json path
+                # format. This allows using cast with json paths literals
+                return value
+            elif value:
+                tokens = [str(elem) for elem in value]
+                return tokens
+            else:
+                return []
 
         return process
 
index 35a09cfa769d95e70f250a2445ff71137746a62e..75fa3d7c77dbc98c267d769b9684d59fe9028f47 100644 (file)
@@ -1512,7 +1512,7 @@ colspecs = {
     sqltypes.ARRAY: _array.ARRAY,
     sqltypes.Interval: INTERVAL,
     sqltypes.Enum: ENUM,
-    sqltypes.JSON.JSONPathType: _json.JSONPathType,
+    sqltypes.JSON.JSONPathType: _json.JSONPATH,
     sqltypes.JSON: _json.JSON,
     UUID: PGUuid,
 }
@@ -2503,6 +2503,12 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
             count=1,
         )
 
+    def visit_json_path(self, type_, **kw):
+        return self.visit_JSONPATH(type_, **kw)
+
+    def visit_JSONPATH(self, type_, **kw):
+        return "JSONPATH"
+
 
 class PGIdentifierPreparer(compiler.IdentifierPreparer):
 
index 8763a0ca20cda0e1fa1e47f187c0d722f70f6f8d..a8b03bd482ac75ca5ffef821135d675258331e9f 100644 (file)
@@ -6,7 +6,6 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
 
-import collections.abc as collections_abc
 
 from ... import types as sqltypes
 from ...sql import operators
@@ -68,31 +67,47 @@ CONTAINED_BY = operators.custom_op(
 
 
 class JSONPathType(sqltypes.JSON.JSONPathType):
-    def bind_processor(self, dialect):
-        super_proc = self.string_bind_processor(dialect)
-
+    def _processor(self, dialect, super_proc):
         def process(value):
-            assert isinstance(value, collections_abc.Sequence)
-            tokens = [str(elem) for elem in value]
-            value = "{%s}" % (", ".join(tokens))
+            if isinstance(value, str):
+                # If it's already a string assume that it's in json path
+                # format. This allows using cast with json paths literals
+                return value
+            elif value:
+                # If it's already a string assume that it's in json path
+                # format. This allows using cast with json paths literals
+                value = "{%s}" % (", ".join(map(str, value)))
+            else:
+                value = "{}"
             if super_proc:
                 value = super_proc(value)
             return value
 
         return process
 
+    def bind_processor(self, dialect):
+        return self._processor(dialect, self.string_bind_processor(dialect))
+
     def literal_processor(self, dialect):
-        super_proc = self.string_literal_processor(dialect)
+        return self._processor(dialect, self.string_literal_processor(dialect))
 
-        def process(value):
-            assert isinstance(value, collections_abc.Sequence)
-            tokens = [str(elem) for elem in value]
-            value = "{%s}" % (", ".join(tokens))
-            if super_proc:
-                value = super_proc(value)
-            return value
 
-        return process
+class JSONPATH(JSONPathType):
+    """JSON Path Type.
+
+    This is usually required to cast literal values to json path when using
+    json search like function, such as ``jsonb_path_query_array`` or
+    ``jsonb_path_exists``::
+
+        stmt = sa.select(
+            sa.func.jsonb_path_query_array(
+                table.c.jsonb_col, cast("$.address.id", JSONPATH)
+            )
+        )
+
+    """
+
+    __visit_name__ = "JSONPATH"
 
 
 class JSON(sqltypes.JSON):
index de833cd89369b690c432597690a018bd88173d8e..f04d583e11d5e23b86c01c3a18765f808f38f8c3 100644 (file)
@@ -2419,6 +2419,8 @@ class JSON(Indexable, TypeEngine[Any]):
 
         """
 
+        __visit_name__ = "json_path"
+
     class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
index 5e5c4f9bdb6b0db7398015a61a56f27eac0b8aa2..67e54e4f5176d795ed2509fc368cccd7deadc8fa 100644 (file)
@@ -41,6 +41,8 @@ from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
 from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import JSONPATH
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
@@ -2354,6 +2356,18 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE data.group_id = summary.group_id)",
         )
 
+    @testing.combinations(JSONB.JSONPathType, JSONPATH)
+    def test_json_path(self, type_):
+        data = table("data", column("id", Integer), column("x", JSONB))
+        stmt = select(
+            func.jsonb_path_exists(data.c.x, cast("$.data.w", type_))
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT jsonb_path_exists(data.x, CAST(%(param_1)s AS JSONPATH)) "
+            "AS jsonb_path_exists_1 FROM data",
+        )
+
 
 class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()
index b4c19238d324c0a8a27154d381f6111a48e11dd7..8f655362e848c38e78c6885e1d33e89dd363bb1b 100644 (file)
@@ -4496,8 +4496,11 @@ class JSONRoundTripTest(fixtures.TablesTest):
             Column("nulldata", cls.data_type(none_as_null=True)),
         )
 
+    @property
+    def data_table(self):
+        return self.tables.data_table
+
     def _fixture_data(self, connection):
-        data_table = self.tables.data_table
 
         data = [
             {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
@@ -4507,23 +4510,23 @@ class JSONRoundTripTest(fixtures.TablesTest):
             {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
             {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
         ]
-        connection.execute(data_table.insert(), data)
+        connection.execute(self.data_table.insert(), data)
         return data
 
     def _assert_data(self, compare, conn, column="data"):
-        col = self.tables.data_table.c[column]
+        col = self.data_table.c[column]
         data = conn.execute(
-            select(col).order_by(self.tables.data_table.c.name)
+            select(col).order_by(self.data_table.c.name)
         ).fetchall()
         eq_([d for d, in data], compare)
 
     def _assert_column_is_NULL(self, conn, column="data"):
-        col = self.tables.data_table.c[column]
+        col = self.data_table.c[column]
         data = conn.execute(select(col).where(col.is_(null()))).fetchall()
         eq_([d for d, in data], [None])
 
     def _assert_column_is_JSON_NULL(self, conn, column="data"):
-        col = self.tables.data_table.c[column]
+        col = self.data_table.c[column]
         data = conn.execute(
             select(col).where(cast(col, String) == "null")
         ).fetchall()
@@ -4539,7 +4542,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
         argnames="key",
     )
     def test_indexed_special_keys(self, connection, key):
-        data_table = self.tables.data_table
+        data_table = self.data_table
         data_element = {key: "some value"}
 
         connection.execute(
@@ -4563,27 +4566,27 @@ class JSONRoundTripTest(fixtures.TablesTest):
 
     def test_insert(self, connection):
         connection.execute(
-            self.tables.data_table.insert(),
+            self.data_table.insert(),
             {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
         )
         self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
 
     def test_insert_nulls(self, connection):
         connection.execute(
-            self.tables.data_table.insert(), {"name": "r1", "data": null()}
+            self.data_table.insert(), {"name": "r1", "data": null()}
         )
         self._assert_data([None], connection)
 
     def test_insert_none_as_null(self, connection):
         connection.execute(
-            self.tables.data_table.insert(),
+            self.data_table.insert(),
             {"name": "r1", "nulldata": None},
         )
         self._assert_column_is_NULL(connection, column="nulldata")
 
     def test_insert_nulljson_into_none_as_null(self, connection):
         connection.execute(
-            self.tables.data_table.insert(),
+            self.data_table.insert(),
             {"name": "r1", "nulldata": JSON.NULL},
         )
         self._assert_column_is_JSON_NULL(connection, column="nulldata")
@@ -4654,9 +4657,8 @@ class JSONRoundTripTest(fixtures.TablesTest):
 
     def test_query_returned_as_text(self, connection):
         self._fixture_data(connection)
-        data_table = self.tables.data_table
         result = connection.execute(
-            select(data_table.c.data["k1"].astext)
+            select(self.data_table.c.data["k1"].astext)
         ).first()
         assert isinstance(result[0], str)
 
@@ -4795,6 +4797,21 @@ class JSONBRoundTripTest(JSONRoundTripTest):
     def test_unicode_round_trip(self, connection):
         super(JSONBRoundTripTest, self).test_unicode_round_trip(connection)
 
+    @testing.only_on("postgresql >= 12")
+    def test_cast_jsonpath(self, connection):
+        self._fixture_data(connection)
+
+        def go(path, res):
+            q = select(func.count("*")).where(
+                func.jsonb_path_exists(
+                    self.data_table.c.data, cast(path, JSONB.JSONPathType)
+                )
+            )
+            eq_(connection.scalar(q), res)
+
+        go("$.k1.k2", 0)
+        go("$.k1.r6v1", 1)
+
 
 class JSONBSuiteTest(suite.JSONTest):
     __requires__ = ("postgresql_jsonb",)