git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rework JSON expressions to be based off __getitem__ exclusively
author: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 20:13:39 +0000 (15:13 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 20:13:39 +0000 (15:13 -0500)
- add support for "standalone" JSON objects; this involves getting CAST
to upgrade the given type of a bound parameter.  should add a core-only test
for this.
- add tests for "standalone" json round trips both with and without unicode
- add mechanism by which we remove psycopg2's "json" handler in order to get
the effect of using our non-native result handlers

lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/json.py [moved from lib/sqlalchemy/dialects/postgresql/pgjson.py with 60% similarity]
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/sql/elements.py
test/dialect/postgresql/test_types.py

index 728f1629fba26fcd10c192abdfdef9a9f53d7fc3..cfe1ebce0bde17beecd4aa0b788ec6f7a285744b 100644 (file)
@@ -15,7 +15,7 @@ from .base import \
     TSVECTOR
 from .constraints import ExcludeConstraint
 from .hstore import HSTORE, hstore
-from .pgjson import JSON
+from .json import JSON
 from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
     TSTZRANGE
 
similarity index 60%
rename from lib/sqlalchemy/dialects/postgresql/pgjson.py
rename to lib/sqlalchemy/dialects/postgresql/json.py
index a29d0bbcc096458a3b384f91854ec02788d13169..5b8ad68f59379665d3b0c7f10dcb8429a092e1a0 100644 (file)
@@ -3,16 +3,16 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
+from __future__ import absolute_import
 
 import json
 
-from .base import ARRAY, ischema_names
+from .base import ischema_names
 from ... import types as sqltypes
-from ...sql import functions as sqlfunc
 from ...sql.operators import custom_op
 from ... import util
 
-__all__ = ('JSON', 'json')
+__all__ = ('JSON', )
 
 
 class JSON(sqltypes.TypeEngine):
@@ -39,21 +39,25 @@ class JSON(sqltypes.TypeEngine):
 
     * Index operations returning text (required for text comparison or casting)::
 
-        data_table.c.data.get_item_as_text('some key') == 'some value'
+        data_table.c.data.astext['some key'] == 'some value'
 
     * Path index operations::
 
-        data_table.c.data.get_path("{key_1, key_2, ..., key_n}")
+        data_table.c.data[('key_1', 'key_2', ..., 'key_n')]
 
     * Path index operations returning text (required for text comparison or casting)::
 
-        data_table.c.data.get_path("{key_1, key_2, ..., key_n}") == 'some value'
+        data_table.c.data.astext[('key_1', 'key_2', ..., 'key_n')] == 'some value'
 
-    Please be aware that when used with the SQLAlchemy ORM, you will need to
-    replace the JSON object present on an attribute with a new object in order
-    for any changes to be properly persisted.
+    The :class:`.JSON` type, when used with the SQLAlchemy ORM, does not detect
+    in-place mutations to the structure.  In order to detect these, the
+    :mod:`sqlalchemy.ext.mutable` extension must be used.  This extension will
+    allow "in-place" changes to the datastructure to produce events which
+    will be detected by the unit of work.  See the example at :class:`.HSTORE`
+    for a simple example involving a dictionary.
 
     .. versionadded:: 0.9
+
     """
 
     __visit_name__ = 'JSON'
@@ -71,31 +75,35 @@ class JSON(sqltypes.TypeEngine):
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for :class:`.JSON`."""
 
+        class _astext(object):
+            def __init__(self, parent):
+                self.parent = parent
+
+            def __getitem__(self, other):
+                return self.parent.expr._get_item(other, True)
+
+        def _get_item(self, other, astext):
+            if hasattr(other, '__iter__') and \
+                not isinstance(other, util.string_types):
+                op = "#>"
+                other = "{%s}" % (", ".join(util.text_type(elem) for elem in other))
+            else:
+                op = "->"
+
+            if astext:
+                op += ">"
+
+            # ops: ->, ->>, #>, #>>
+            return self.expr.op(op, precedence=5)(other)
+
         def __getitem__(self, other):
-            """Text expression.  Get the value at a given key."""
-            # I'm choosing to return text here so the result can be cast,
-            # compared with strings, etc.
-            #
-            # The only downside to this is that you cannot dereference more
-            # than one level deep in json structures, though comparator
-            # support for multi-level dereference is lacking anyhow.
-            return self.expr.op('->', precedence=5)(other)
-
-        def get_item_as_text(self, other):
-            """Text expression.  Get the value at the given key as text.  Use
-            this when you need to cast the type of the returned value."""
-            return self.expr.op('->>', precedence=5)(other)
-
-        def get_path(self, other):
-            """Text expression.  Get the value at a given path. Paths are of
-            the form {key_1, key_2, ..., key_n}."""
-            return self.expr.op('#>', precedence=5)(other)
-
-        def get_path_as_text(self, other):
-            """Text expression.  Get the value at a given path, as text.
-            Paths are of the form {key_1, key_2, ..., key_n}.  Use this when
-            you need to cast the type of the returned value."""
-            return self.expr.op('#>>', precedence=5)(other)
+            """Get the value at a given key."""
+
+            return self._get_item(other, False)
+
+        @property
+        def astext(self):
+            return self._astext(self)
 
         def _adapt_expression(self, op, other_comparator):
             if isinstance(op, custom_op):
index 4a9248e5ffb8998a6055978001a8917631d804ca..ceb04b5801d60e882bb436ac3ed04761c9091245 100644 (file)
@@ -179,7 +179,7 @@ from .base import PGDialect, PGCompiler, \
                                 ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\
                                 _INT_TYPES
 from .hstore import HSTORE
-from .pgjson import JSON
+from .json import JSON
 
 
 logger = logging.getLogger('sqlalchemy.dialects.postgresql')
@@ -236,9 +236,7 @@ class _PGHStore(HSTORE):
 
 
 class _PGJSON(JSON):
-    # I've omitted the bind processor here because the method of serializing
-    # involves registering specific types to auto-serialize, and the adapter
-    # just a thin wrapper over json.dumps.
+
     def result_processor(self, dialect, coltype):
         if dialect._has_native_json:
             return None
index 045056b42de5710be67114e40c054ab3117523a3..69e365bd350296c9fdff8ada8b0e4e4ce842df0c 100644 (file)
@@ -1753,6 +1753,10 @@ class Cast(ColumnElement):
         """
         self.type = type_api.to_instance(totype)
         self.clause = _literal_as_binds(clause, None)
+        if isinstance(self.clause, BindParameter) and self.clause.type._isnull:
+            self.clause = self.clause._clone()
+            self.clause.type = self.type
+
         self.typeclause = TypeClause(self.type)
 
     def _copy_internals(self, clone=_clone, **kw):
index 19df131fd3c5b051171609d0d8ab967f19f2a37b..5da2520f36cbafccb4b10f6a14e68607d54cb3e7 100644 (file)
@@ -10,7 +10,8 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
             PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \
             func, literal_column, literal, bindparam, cast, extract, \
             SmallInteger, Enum, REAL, update, insert, Index, delete, \
-            and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text
+            and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text, \
+            type_coerce
 from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
@@ -23,6 +24,8 @@ from sqlalchemy.testing.util import round_decimal
 from sqlalchemy.sql import table, column, operators
 import logging
 import re
+from sqlalchemy import inspect
+from sqlalchemy import event
 
 class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
     __only_on__ = 'postgresql'
@@ -965,14 +968,7 @@ class UUIDTest(fixtures.TestBase):
 
 
 
-class HStoreTest(fixtures.TestBase):
-    def _assert_sql(self, construct, expected):
-        dialect = postgresql.dialect()
-        compiled = str(construct.compile(dialect=dialect))
-        compiled = re.sub(r'\s+', ' ', compiled)
-        expected = re.sub(r'\s+', ' ', expected)
-        eq_(compiled, expected)
-
+class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
     def setup(self):
         metadata = MetaData()
         self.test_table = Table('test_table', metadata,
@@ -983,7 +979,7 @@ class HStoreTest(fixtures.TestBase):
 
     def _test_where(self, whereclause, expected):
         stmt = select([self.test_table]).where(whereclause)
-        self._assert_sql(
+        self.assert_compile(
             stmt,
             "SELECT test_table.id, test_table.hash FROM test_table "
             "WHERE %s" % expected
@@ -991,7 +987,7 @@ class HStoreTest(fixtures.TestBase):
 
     def _test_cols(self, colclause, expected, from_=True):
         stmt = select([colclause])
-        self._assert_sql(
+        self.assert_compile(
             stmt,
             (
                 "SELECT %s" +
@@ -1292,7 +1288,6 @@ class HStoreRoundTripTest(fixtures.TablesTest):
         return engine
 
     def test_reflect(self):
-        from sqlalchemy import inspect
         insp = inspect(testing.db)
         cols = insp.get_columns('data_table')
         assert isinstance(cols[2]['type'], HSTORE)
@@ -1666,13 +1661,7 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
         return self.extras.DateTimeTZRange(*self.tstzs())
 
 
-class JSONTest(fixtures.TestBase):
-    def _assert_sql(self, construct, expected):
-        dialect = postgresql.dialect()
-        compiled = str(construct.compile(dialect=dialect))
-        compiled = re.sub(r'\s+', ' ', compiled)
-        expected = re.sub(r'\s+', ' ', expected)
-        eq_(compiled, expected)
+class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
 
     def setup(self):
         metadata = MetaData()
@@ -1684,7 +1673,7 @@ class JSONTest(fixtures.TestBase):
 
     def _test_where(self, whereclause, expected):
         stmt = select([self.test_table]).where(whereclause)
-        self._assert_sql(
+        self.assert_compile(
             stmt,
             "SELECT test_table.id, test_table.test_column FROM test_table "
             "WHERE %s" % expected
@@ -1692,7 +1681,7 @@ class JSONTest(fixtures.TestBase):
 
     def _test_cols(self, colclause, expected, from_=True):
         stmt = select([colclause])
-        self._assert_sql(
+        self.assert_compile(
             stmt,
             (
                 "SELECT %s" +
@@ -1730,19 +1719,19 @@ class JSONTest(fixtures.TestBase):
 
     def test_where_path(self):
         self._test_where(
-            self.jsoncol.get_path('{"foo", 1}') == None,
+            self.jsoncol[("foo", 1)] == None,
             "(test_table.test_column #> %(test_column_1)s) IS NULL"
         )
 
     def test_where_getitem_as_text(self):
         self._test_where(
-            self.jsoncol.get_item_as_text('bar') == None,
+            self.jsoncol.astext['bar'] == None,
             "(test_table.test_column ->> %(test_column_1)s) IS NULL"
         )
 
     def test_where_path_as_text(self):
         self._test_where(
-            self.jsoncol.get_path_as_text('{"foo", 1}') == None,
+            self.jsoncol.astext[("foo", 1)] == None,
             "(test_table.test_column #>> %(test_column_1)s) IS NULL"
         )
 
@@ -1755,7 +1744,7 @@ class JSONTest(fixtures.TestBase):
 
 
 class JSONRoundTripTest(fixtures.TablesTest):
-    __only_on__ = 'postgresql'
+    __only_on__ = ('postgresql >= 9.3',)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -1792,14 +1781,20 @@ class JSONRoundTripTest(fixtures.TablesTest):
 
     def _non_native_engine(self):
         if testing.against("postgresql+psycopg2"):
+            from psycopg2.extras import register_default_json
             engine = engines.testing_engine()
+            @event.listens_for(engine, "connect")
+            def connect(dbapi_connection, connection_record):
+                engine.dialect._has_native_json = False
+                def pass_(value):
+                    return value
+                register_default_json(dbapi_connection, loads=pass_)
         else:
             engine = testing.db
         engine.connect()
         return engine
 
     def test_reflect(self):
-        from sqlalchemy import inspect
         insp = inspect(testing.db)
         cols = insp.get_columns('data_table')
         assert isinstance(cols[2]['type'], JSON)
@@ -1830,7 +1825,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
         data_table = self.tables.data_table
         result = engine.execute(
             select([data_table.c.data]).where(
-                data_table.c.data.get_path_as_text('{k1}') == 'r3v1'
+                data_table.c.data.astext[('k1',)] == 'r3v1'
             )
         ).first()
         eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
@@ -1840,7 +1835,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
         self._fixture_data(engine)
         data_table = self.tables.data_table
         result = engine.execute(
-            select([data_table.c.data.get_item_as_text('k1')])
+            select([data_table.c.data.astext['k1']])
         ).first()
         assert isinstance(result[0], basestring)
 
@@ -1848,7 +1843,61 @@ class JSONRoundTripTest(fixtures.TablesTest):
         data_table = self.tables.data_table
         result = engine.execute(
             select([data_table.c.data]).where(
-                data_table.c.data.get_item_as_text('k1') == 'r3v1'
+                data_table.c.data.astext['k1'] == 'r3v1'
             )
         ).first()
         eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+
+    def _test_fixed_round_trip(self, engine):
+        s = select([
+                cast(
+                    {
+                        "key": "value",
+                        "key2": {"k1": "v1", "k2": "v2"}
+                    },
+                    JSON
+                )
+            ])
+        eq_(
+            engine.scalar(s),
+            {
+                "key": "value",
+                "key2": {"k1": "v1", "k2": "v2"}
+            },
+        )
+
+    def test_fixed_round_trip_python(self):
+        engine = self._non_native_engine()
+        self._test_fixed_round_trip(engine)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_fixed_round_trip_native(self):
+        engine = testing.db
+        self._test_fixed_round_trip(engine)
+
+    def _test_unicode_round_trip(self, engine):
+        s = select([
+            cast(
+                {
+                    util.u('réveillé'): util.u('réveillé'),
+                    "data": {"k1": util.u('drôle')}
+                },
+                JSON
+            )
+        ])
+        eq_(
+            engine.scalar(s),
+                {
+                    util.u('réveillé'): util.u('réveillé'),
+                    "data": {"k1": util.u('drôle')}
+                },
+        )
+
+    def test_unicode_round_trip_python(self):
+        engine = self._non_native_engine()
+        self._test_unicode_round_trip(engine)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_unicode_round_trip_native(self):
+        engine = testing.db
+        self._test_unicode_round_trip(engine)