From: nathan Date: Mon, 9 Dec 2013 16:46:36 +0000 (-0500) Subject: sqlalchemy/dialects/postgresql/__init__.py: X-Git-Tag: rel_0_9_0~26^2~6^2~5 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=64288c7d6ffc021e2388aa764e9a3b921506c7a0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git sqlalchemy/dialects/postgresql/__init__.py: - Added import references to JSON class sqlalchemy/dialects/postgresql/base.py: - Added visitor method for JSON class sqlalchemy/dialects/postgresql/pgjson (new): - JSON class, supports automatic serialization and deserialization of json data, as well as basic json operators. --- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 408b678467..00bbc7268d 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -14,6 +14,7 @@ from .base import \ DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All from .constraints import ExcludeConstraint from .hstore import HSTORE, hstore +from .pgjson import JSON from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ TSTZRANGE @@ -23,5 +24,5 @@ __all__ = ( 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'Any', 'All', 'array', 'HSTORE', 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', - 'TSRANGE', 'TSTZRANGE' + 'TSRANGE', 'TSTZRANGE', 'json', 'JSON' ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index b80f269c14..6469f3b702 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1187,6 +1187,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_HSTORE(self, type_): return "HSTORE" + def visit_JSON(self, type_): + return "JSON" + def visit_INT4RANGE(self, type_): return "INT4RANGE" diff --git a/lib/sqlalchemy/dialects/postgresql/pgjson.py b/lib/sqlalchemy/dialects/postgresql/pgjson.py new file mode 100644 index 0000000000..aef54709bb --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pgjson.py @@ -0,0 +1,109 @@ +# postgresql/json.py +# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import json + +from .base import ARRAY, ischema_names +from ... import types as sqltypes +from ...sql import functions as sqlfunc +from ...sql.operators import custom_op +from ... import util + +__all__ = ('JSON', 'json') + + +class JSON(sqltypes.TypeEngine): + """Represent the Postgresql HSTORE type. + + The :class:`.JSON` type stores arbitrary JSON format data, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSON) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.JSON` provides two operations: + + * Index operations:: + + data_table.c.data['some key'] == 'some value' + + * Path Index operations:: + + data_table.c.data.get_path('{key_1, key_2, ..., key_n}'] + + Please be aware that when used with the SQL Alchemy ORM, you will need to + replace the JSON object present on an attribute with a new object in order + for any changes to be properly persisted. + + .. versionadded:: 0.9 + """ + + __visit_name__ = 'JSON' + + def __init__(self, json_serializer=None, json_deserializer=None): + if json_serializer: + self.json_serializer = json_serializer + else: + self.json_serializer = json.dumps + if json_deserializer: + self.json_deserializer = json_deserializer + else: + self.json_deserializer = json.loads + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for :class:`.JSON`.""" + + def __getitem__(self, other): + """Text expression. Get the value at a given key.""" + # I'm choosing to return text here so the result can be cast, + # compared with strings, etc. + # + # The only downside to this is that you cannot dereference more + # than one level deep in json structures, though comparator + # support for multi-level dereference is lacking anyhow. + return self.expr.op('->>', precedence=5)(other) + + def get_path(self, other): + """Text expression. Get the value at a given path. Paths are of + the form {key_1, key_2, ..., key_n}.""" + return self.expr.op('#>>', precedence=5)(other) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, custom_op): + if op.opstring == '->': + return op, sqltypes.Text + return sqltypes.Concatenable.Comparator.\ + _adapt_expression(self, op, other_comparator) + + def bind_processor(self, dialect): + if util.py2k: + encoding = dialect.encoding + def process(value): + return self.json_serializer(value).encode(encoding) + else: + def process(value): + return self.json_serializer(value) + return process + + def result_processor(self, dialect, coltype): + if util.py2k: + encoding = dialect.encoding + def process(value): + return self.json_deserializer(value.decode(encoding)) + else: + def process(value): + return self.json_deserializer(value) + return process + + +ischema_names['json'] = JSON diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 0675ebd5d3..5a944ae9d0 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -15,7 +15,8 @@ from sqlalchemy.orm import Session, mapper, aliased from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, \ - INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE + INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, TSTZRANGE, \ + JSON import decimal from sqlalchemy import util from sqlalchemy.testing.util import round_decimal @@ -1651,3 +1652,157 @@ class DateTimeTZRangeTests(_RangeTypeMixin, fixtures.TablesTest): def _data_obj(self): return self.extras.DateTimeTZRange(*self.tstzs()) + + +class JSONTest(fixtures.TestBase): + def _assert_sql(self, construct, expected): + dialect = postgresql.dialect() + compiled = str(construct.compile(dialect=dialect)) + compiled = re.sub(r'\s+', ' ', compiled) + expected = re.sub(r'\s+', ' ', expected) + eq_(compiled, expected) + + def setup(self): + metadata = MetaData() + self.test_table = Table('test_table', metadata, + Column('id', Integer, primary_key=True), + Column('test_column', JSON) + ) + self.jsoncol = self.test_table.c.test_column + + def _test_where(self, whereclause, expected): + stmt = select([self.test_table]).where(whereclause) + self._assert_sql( + stmt, + "SELECT test_table.id, test_table.test_column FROM test_table " + "WHERE %s" % expected + ) + + def _test_cols(self, colclause, expected, from_=True): + stmt = select([colclause]) + self._assert_sql( + stmt, + ( + "SELECT %s" + + (" FROM test_table" if from_ else "") + ) % expected + ) + + def test_bind_serialize_default(self): + from sqlalchemy.engine import default + + dialect = default.DefaultDialect() + proc = self.test_table.c.test_column.type._cached_bind_processor(dialect) + eq_( + proc({"A": [1, 2, 3, True, False]}), + '{"A": [1, 2, 3, true, false]}' + ) + + def test_result_deserialize_default(self): + from sqlalchemy.engine import default + + dialect = default.DefaultDialect() + proc = self.test_table.c.test_column.type._cached_result_processor( + dialect, None) + eq_( + proc('{"A": [1, 2, 3, true, false]}'), + {"A": [1, 2, 3, True, False]} + ) + + # This test is a bit misleading -- in real life you will need to cast to do anything + def test_where_getitem(self): + self._test_where( + self.jsoncol['bar'] == None, + "(test_table.test_column ->> %(test_column_1)s) IS NULL" + ) + + def test_where_path(self): + self._test_where( + self.jsoncol.get_path('{"foo", 1}') == None, + "(test_table.test_column #>> %(test_column_1)s) IS NULL" + ) + + def test_cols_get(self): + self._test_cols( + self.jsoncol['foo'], + "test_table.test_column ->> %(test_column_1)s AS anon_1", + True + ) + + +class JSONRoundTripTest(fixtures.TablesTest): + __only_on__ = 'postgresql' + + @classmethod + def define_tables(cls, metadata): + Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30), nullable=False), + Column('data', JSON) + ) + + def _fixture_data(self, engine): + data_table = self.tables.data_table + engine.execute( + data_table.insert(), + {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}, + {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}}, + {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}}, + {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}}, + {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}}, + ) + + def _assert_data(self, compare): + data = testing.db.execute( + select([self.tables.data_table.c.data]). + order_by(self.tables.data_table.c.name) + ).fetchall() + eq_([d for d, in data], compare) + + def _test_insert(self, engine): + engine.execute( + self.tables.data_table.insert(), + {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}} + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}]) + + def _non_native_engine(self): + if testing.against("postgresql+psycopg2"): + engine = engines.testing_engine(options=dict(use_native_hstore=False)) + else: + engine = testing.db + engine.connect() + return engine + + def test_reflect(self): + from sqlalchemy import inspect + insp = inspect(testing.db) + cols = insp.get_columns('data_table') + assert isinstance(cols[2]['type'], JSON) + + @testing.only_on("postgresql+psycopg2") + def test_insert_native(self): + engine = testing.db + self._test_insert(engine) + + def test_insert_python(self): + engine = self._non_native_engine() + self._test_insert(engine) + + @testing.only_on("postgresql+psycopg2") + def test_criterion_native(self): + engine = testing.db + self._fixture_data(engine) + self._test_criterion(engine) + + def test_criterion_python(self): + engine = self._non_native_engine() + self._fixture_data(engine) + self._test_criterion(engine) + + def _test_criterion(self, engine): + data_table = self.tables.data_table + result = engine.execute( + select([data_table.c.data]).where(data_table.c.data['k1'] == 'r3v1') + ).first() + eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},)) \ No newline at end of file