]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merge branch 'issue_2581' of github.com:nathan-rice/sqlalchemy into pg_json
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 19:03:20 +0000 (14:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 19:03:20 +0000 (14:03 -0500)
1  2 
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
test/dialect/postgresql/test_types.py

index 3c1d19504d8c7ca0c1a141c4b3e54461e459fac9,00bbc7268ddba133bf1f0d16c9c30adfe5c03ebb..728f1629fba26fcd10c192abdfdef9a9f53d7fc3
@@@ -11,10 -11,10 +11,11 @@@ base.dialect = psycopg2.dialec
  from .base import \
      INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
      INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
 -    DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
 +    DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All, \
 +    TSVECTOR
  from .constraints import ExcludeConstraint
  from .hstore import HSTORE, hstore
+ from .pgjson import JSON
  from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \
      TSTZRANGE
  
index c3c749523cbe6bde6cee12675e06046fb5e38eeb,0dbdfe8fbd0b97229778bed1dd9138cd88d155de..4a9248e5ffb8998a6055978001a8917631d804ca
@@@ -324,7 -347,9 +337,8 @@@ class PGDialect_psycopg2(PGDialect)
              sqltypes.Numeric: _PGNumeric,
              ENUM: _PGEnum,  # needs force_unicode
              sqltypes.Enum: _PGEnum,  # needs force_unicode
 -            ARRAY: _PGArray,  # needs force_unicode
              HSTORE: _PGHStore,
+             JSON: _PGJSON
          }
      )
  
index 6e8609448b3c6a4152a44d525ffcd16926daac5c,c7a973e4e0e91656f766ddda4cebf9de7c784548..19df131fd3c5b051171609d0d8ab967f19f2a37b
@@@ -1663,3 -1652,191 +1664,191 @@@ class DateTimeTZRangeTests(_RangeTypeMi
  
      def _data_obj(self):
          return self.extras.DateTimeTZRange(*self.tstzs())
 -        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+ class JSONTest(fixtures.TestBase):
+     def _assert_sql(self, construct, expected):
+         dialect = postgresql.dialect()
+         compiled = str(construct.compile(dialect=dialect))
+         compiled = re.sub(r'\s+', ' ', compiled)
+         expected = re.sub(r'\s+', ' ', expected)
+         eq_(compiled, expected)
+     def setup(self):
+         metadata = MetaData()
+         self.test_table = Table('test_table', metadata,
+             Column('id', Integer, primary_key=True),
+             Column('test_column', JSON)
+         )
+         self.jsoncol = self.test_table.c.test_column
+     def _test_where(self, whereclause, expected):
+         stmt = select([self.test_table]).where(whereclause)
+         self._assert_sql(
+             stmt,
+             "SELECT test_table.id, test_table.test_column FROM test_table "
+             "WHERE %s" % expected
+         )
+     def _test_cols(self, colclause, expected, from_=True):
+         stmt = select([colclause])
+         self._assert_sql(
+             stmt,
+             (
+                 "SELECT %s" +
+                 (" FROM test_table" if from_ else "")
+             ) % expected
+         )
+     def test_bind_serialize_default(self):
+         from sqlalchemy.engine import default
+         dialect = default.DefaultDialect()
+         proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
+         eq_(
+             proc({"A": [1, 2, 3, True, False]}),
+             '{"A": [1, 2, 3, true, false]}'
+         )
+     def test_result_deserialize_default(self):
+         from sqlalchemy.engine import default
+         dialect = default.DefaultDialect()
+         proc = self.test_table.c.test_column.type._cached_result_processor(
+                     dialect, None)
+         eq_(
+             proc('{"A": [1, 2, 3, true, false]}'),
+             {"A": [1, 2, 3, True, False]}
+         )
+     # This test is a bit misleading -- in real life you will need to cast to do anything
+     def test_where_getitem(self):
+         self._test_where(
+             self.jsoncol['bar'] == None,
+             "(test_table.test_column -> %(test_column_1)s) IS NULL"
+         )
+     def test_where_path(self):
+         self._test_where(
+             self.jsoncol.get_path('{"foo", 1}') == None,
+             "(test_table.test_column #> %(test_column_1)s) IS NULL"
+         )
+     def test_where_getitem_as_text(self):
+         self._test_where(
+             self.jsoncol.get_item_as_text('bar') == None,
+             "(test_table.test_column ->> %(test_column_1)s) IS NULL"
+         )
+     def test_where_path_as_text(self):
+         self._test_where(
+             self.jsoncol.get_path_as_text('{"foo", 1}') == None,
+             "(test_table.test_column #>> %(test_column_1)s) IS NULL"
+         )
+     def test_cols_get(self):
+         self._test_cols(
+             self.jsoncol['foo'],
+             "test_table.test_column -> %(test_column_1)s AS anon_1",
+             True
+         )
+ class JSONRoundTripTest(fixtures.TablesTest):
+     __only_on__ = 'postgresql'
+     @classmethod
+     def define_tables(cls, metadata):
+         Table('data_table', metadata,
+             Column('id', Integer, primary_key=True),
+             Column('name', String(30), nullable=False),
+             Column('data', JSON)
+         )
+     def _fixture_data(self, engine):
+         data_table = self.tables.data_table
+         engine.execute(
+                 data_table.insert(),
+                 {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
+                 {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
+                 {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
+                 {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
+                 {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
+         )
+     def _assert_data(self, compare):
+         data = testing.db.execute(
+             select([self.tables.data_table.c.data]).
+                 order_by(self.tables.data_table.c.name)
+         ).fetchall()
+         eq_([d for d, in data], compare)
+     def _test_insert(self, engine):
+         engine.execute(
+             self.tables.data_table.insert(),
+             {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
+         )
+         self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
+     def _non_native_engine(self):
+         if testing.against("postgresql+psycopg2"):
+             engine = engines.testing_engine()
+         else:
+             engine = testing.db
+         engine.connect()
+         return engine
+     def test_reflect(self):
+         from sqlalchemy import inspect
+         insp = inspect(testing.db)
+         cols = insp.get_columns('data_table')
+         assert isinstance(cols[2]['type'], JSON)
+     @testing.only_on("postgresql+psycopg2")
+     def test_insert_native(self):
+         engine = testing.db
+         self._test_insert(engine)
+     def test_insert_python(self):
+         engine = self._non_native_engine()
+         self._test_insert(engine)
+     @testing.only_on("postgresql+psycopg2")
+     def test_criterion_native(self):
+         engine = testing.db
+         self._fixture_data(engine)
+         self._test_criterion(engine)
+     def test_criterion_python(self):
+         engine = self._non_native_engine()
+         self._fixture_data(engine)
+         self._test_criterion(engine)
+     def test_path_query(self):
+         engine = testing.db
+         self._fixture_data(engine)
+         data_table = self.tables.data_table
+         result = engine.execute(
+             select([data_table.c.data]).where(
+                 data_table.c.data.get_path_as_text('{k1}') == 'r3v1'
+             )
+         ).first()
+         eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+     def test_query_returned_as_text(self):
+         engine = testing.db
+         self._fixture_data(engine)
+         data_table = self.tables.data_table
+         result = engine.execute(
+             select([data_table.c.data.get_item_as_text('k1')])
+         ).first()
+         assert isinstance(result[0], basestring)
+     def _test_criterion(self, engine):
+         data_table = self.tables.data_table
+         result = engine.execute(
+             select([data_table.c.data]).where(
+                 data_table.c.data.get_item_as_text('k1') == 'r3v1'
+             )
+         ).first()
++        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))