]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sqlalchemy/dialects/postgresql/__init__.py:
authornathan <nathan.alexander.rice@gmail.com>
Mon, 9 Dec 2013 16:46:36 +0000 (11:46 -0500)
committernathan <nathan.alexander.rice@gmail.com>
Mon, 9 Dec 2013 16:46:36 +0000 (11:46 -0500)
- Added import references to JSON class

 sqlalchemy/dialects/postgresql/base.py:
 - Added visitor method for JSON class

 sqlalchemy/dialects/postgresql/pgjson (new):
 - JSON class, supports automatic serialization and deserialization of json data, as well as basic json operators.

lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pgjson.py [new file with mode: 0644]
test/dialect/postgresql/test_types.py

index 408b678467ded92201554db7dc0013f8f9b01e2e..00bbc7268ddba133bf1f0d16c9c30adfe5c03ebb 100644 (file)
@@ -14,6 +14,7 @@ from .base import \
     DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
 from .constraints import ExcludeConstraint
 from .hstore import HSTORE, hstore
+from .pgjson import JSON
 from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
     TSTZRANGE
 
@@ -23,5 +24,5 @@ __all__ = (
     'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
     'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE',
     'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE',
-    'TSRANGE', 'TSTZRANGE'
+    'TSRANGE', 'TSTZRANGE', 'json', 'JSON'
 )
index b80f269c149398d8fae499bf69f79896e78a250c..6469f3b7026f783a9ef07f2104549f6013c3cea5 100644 (file)
@@ -1187,6 +1187,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_HSTORE(self, type_):
         return "HSTORE"
 
+    def visit_JSON(self, type_):
+        return "JSON"
+
     def visit_INT4RANGE(self, type_):
         return "INT4RANGE"
 
diff --git a/lib/sqlalchemy/dialects/postgresql/pgjson.py b/lib/sqlalchemy/dialects/postgresql/pgjson.py
new file mode 100644 (file)
index 0000000..aef5470
--- /dev/null
@@ -0,0 +1,109 @@
+# postgresql/json.py
+# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import json
+
+from .base import ARRAY, ischema_names
+from ... import types as sqltypes
+from ...sql import functions as sqlfunc
+from ...sql.operators import custom_op
+from ... import util
+
+__all__ = ('JSON', 'json')
+
+
+class JSON(sqltypes.TypeEngine):
+    """Represent the Postgresql HSTORE type.
+
+    The :class:`.JSON` type stores arbitrary JSON format data, e.g.::
+
+        data_table = Table('data_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', JSON)
+        )
+
+        with engine.connect() as conn:
+            conn.execute(
+                data_table.insert(),
+                data = {"key1": "value1", "key2": "value2"}
+            )
+
+    :class:`.JSON` provides two operations:
+
+    * Index operations::
+
+        data_table.c.data['some key'] == 'some value'
+
+    * Path Index operations::
+
+        data_table.c.data.get_path('{key_1, key_2, ..., key_n}']
+
+    Please be aware that when used with the SQL Alchemy ORM, you will need to
+    replace the JSON object present on an attribute with a new object in order
+    for any changes to be properly persisted.
+
+    .. versionadded:: 0.9
+    """
+
+    __visit_name__ = 'JSON'
+
+    def __init__(self, json_serializer=None, json_deserializer=None):
+        if json_serializer:
+            self.json_serializer = json_serializer
+        else:
+            self.json_serializer = json.dumps
+        if json_deserializer:
+            self.json_deserializer = json_deserializer
+        else:
+            self.json_deserializer = json.loads
+
+    class comparator_factory(sqltypes.Concatenable.Comparator):
+        """Define comparison operations for :class:`.JSON`."""
+
+        def __getitem__(self, other):
+            """Text expression.  Get the value at a given key."""
+            # I'm choosing to return text here so the result can be cast,
+            # compared with strings, etc.
+            #
+            # The only downside to this is that you cannot dereference more
+            # than one level deep in json structures, though comparator
+            # support for multi-level dereference is lacking anyhow.
+            return self.expr.op('->>', precedence=5)(other)
+
+        def get_path(self, other):
+            """Text expression.  Get the value at a given path.  Paths are of
+            the form {key_1, key_2, ..., key_n}."""
+            return self.expr.op('#>>', precedence=5)(other)
+
+        def _adapt_expression(self, op, other_comparator):
+            if isinstance(op, custom_op):
+                if op.opstring == '->':
+                    return op, sqltypes.Text
+            return sqltypes.Concatenable.Comparator.\
+                _adapt_expression(self, op, other_comparator)
+
+    def bind_processor(self, dialect):
+        if util.py2k:
+            encoding = dialect.encoding
+            def process(value):
+                return self.json_serializer(value).encode(encoding)
+        else:
+            def process(value):
+                return self.json_serializer(value)
+        return process
+
+    def result_processor(self, dialect, coltype):
+        if util.py2k:
+            encoding = dialect.encoding
+            def process(value):
+                return self.json_deserializer(value.decode(encoding))
+        else:
+            def process(value):
+                return self.json_deserializer(value)
+        return process
+
+
+ischema_names['json'] = JSON
index 0675ebd5d39ee068780d18d1c09260f2b50a7f27..5a944ae9d0d27aa2b78b145c4342e7a669da412a 100644 (file)
@@ -15,7 +15,8 @@ from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \
-            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE
+            INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \
+            JSON
 import decimal
 from sqlalchemy import util
 from sqlalchemy.testing.util import round_decimal
@@ -1651,3 +1652,157 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest):
 
     def _data_obj(self):
         return self.extras.DateTimeTZRange(*self.tstzs())
+
+
+class JSONTest(fixtures.TestBase):
+    def _assert_sql(self, construct, expected):
+        dialect = postgresql.dialect()
+        compiled = str(construct.compile(dialect=dialect))
+        compiled = re.sub(r'\s+', ' ', compiled)
+        expected = re.sub(r'\s+', ' ', expected)
+        eq_(compiled, expected)
+
+    def setup(self):
+        metadata = MetaData()
+        self.test_table = Table('test_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('test_column', JSON)
+        )
+        self.jsoncol = self.test_table.c.test_column
+
+    def _test_where(self, whereclause, expected):
+        stmt = select([self.test_table]).where(whereclause)
+        self._assert_sql(
+            stmt,
+            "SELECT test_table.id, test_table.test_column FROM test_table "
+            "WHERE %s" % expected
+        )
+
+    def _test_cols(self, colclause, expected, from_=True):
+        stmt = select([colclause])
+        self._assert_sql(
+            stmt,
+            (
+                "SELECT %s" +
+                (" FROM test_table" if from_ else "")
+            ) % expected
+        )
+
+    def test_bind_serialize_default(self):
+        from sqlalchemy.engine import default
+
+        dialect = default.DefaultDialect()
+        proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
+        eq_(
+            proc({"A": [1, 2, 3, True, False]}),
+            '{"A": [1, 2, 3, true, false]}'
+        )
+
+    def test_result_deserialize_default(self):
+        from sqlalchemy.engine import default
+
+        dialect = default.DefaultDialect()
+        proc = self.test_table.c.test_column.type._cached_result_processor(
+                    dialect, None)
+        eq_(
+            proc('{"A": [1, 2, 3, true, false]}'),
+            {"A": [1, 2, 3, True, False]}
+        )
+
+    # This test is a bit misleading -- in real life you will need to cast to do anything
+    def test_where_getitem(self):
+        self._test_where(
+            self.jsoncol['bar'] == None,
+            "(test_table.test_column ->> %(test_column_1)s) IS NULL"
+        )
+
+    def test_where_path(self):
+        self._test_where(
+            self.jsoncol.get_path('{"foo", 1}') == None,
+            "(test_table.test_column #>> %(test_column_1)s) IS NULL"
+        )
+
+    def test_cols_get(self):
+        self._test_cols(
+            self.jsoncol['foo'],
+            "test_table.test_column ->> %(test_column_1)s AS anon_1",
+            True
+        )
+
+
+class JSONRoundTripTest(fixtures.TablesTest):
+    __only_on__ = 'postgresql'
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('data_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30), nullable=False),
+            Column('data', JSON)
+        )
+
+    def _fixture_data(self, engine):
+        data_table = self.tables.data_table
+        engine.execute(
+                data_table.insert(),
+                {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
+                {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
+                {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
+                {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
+                {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
+        )
+
+    def _assert_data(self, compare):
+        data = testing.db.execute(
+            select([self.tables.data_table.c.data]).
+                order_by(self.tables.data_table.c.name)
+        ).fetchall()
+        eq_([d for d, in data], compare)
+
+    def _test_insert(self, engine):
+        engine.execute(
+            self.tables.data_table.insert(),
+            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
+        )
+        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
+
+    def _non_native_engine(self):
+        if testing.against("postgresql+psycopg2"):
+            engine = engines.testing_engine(options=dict(use_native_hstore=False))
+        else:
+            engine = testing.db
+        engine.connect()
+        return engine
+
+    def test_reflect(self):
+        from sqlalchemy import inspect
+        insp = inspect(testing.db)
+        cols = insp.get_columns('data_table')
+        assert isinstance(cols[2]['type'], JSON)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_insert_native(self):
+        engine = testing.db
+        self._test_insert(engine)
+
+    def test_insert_python(self):
+        engine = self._non_native_engine()
+        self._test_insert(engine)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_criterion_native(self):
+        engine = testing.db
+        self._fixture_data(engine)
+        self._test_criterion(engine)
+
+    def test_criterion_python(self):
+        engine = self._non_native_engine()
+        self._fixture_data(engine)
+        self._test_criterion(engine)
+
+    def _test_criterion(self, engine):
+        data_table = self.tables.data_table
+        result = engine.execute(
+            select([data_table.c.data]).where(data_table.c.data['k1'] == 'r3v1')
+        ).first()
+        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
\ No newline at end of file