]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- make the json serializer and deserializer per-dialect, so that we are
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 22:46:09 +0000 (17:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Dec 2013 22:46:09 +0000 (17:46 -0500)
compatible with psycopg2's per-connection/cursor approach.  add round trip tests for
both native and non-native.

lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
test/dialect/postgresql/test_types.py

index a7f8380093653596fed46f521407c3155ad6345c..3edc28fed5e4b6240419d5e501b1a9311a227bd9 100644 (file)
@@ -1439,9 +1439,12 @@ class PGDialect(default.DefaultDialect):
 
     _backslash_escapes = True
 
-    def __init__(self, isolation_level=None, **kwargs):
+    def __init__(self, isolation_level=None, json_serializer=None,
+                    json_deserializer=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
         self.isolation_level = isolation_level
+        self._json_deserializer = json_deserializer
+        self._json_serializer = json_serializer
 
     def initialize(self, connection):
         super(PGDialect, self).initialize(connection)
index 5b8ad68f59379665d3b0c7f10dcb8429a092e1a0..7ba8b1abeea9867b3a11a7843eb89bdad8afe5e1 100644 (file)
@@ -56,22 +56,25 @@ class JSON(sqltypes.TypeEngine):
     will be detected by the unit of work.  See the example at :class:`.HSTORE`
     for a simple example involving a dictionary.
 
+    Custom serializers and deserializers are specified at the dialect level,
+    that is using :func:`.create_engine`.  The reason for this is that when
+    using psycopg2, the DBAPI only allows serializers at the per-cursor
+    or per-connection level.   E.g.::
+
+        engine = create_engine("postgresql://scott:tiger@localhost/test",
+                                json_serializer=my_serialize_fn,
+                                json_deserializer=my_deserialize_fn
+                        )
+
+    When using the psycopg2 dialect, the json_deserializer is registered
+    against the database using ``psycopg2.extras.register_default_json``.
+
     .. versionadded:: 0.9
 
     """
 
     __visit_name__ = 'JSON'
 
-    def __init__(self, json_serializer=None, json_deserializer=None):
-        if json_serializer:
-            self.json_serializer = json_serializer
-        else:
-            self.json_serializer = json.dumps
-        if json_deserializer:
-            self.json_deserializer = json_deserializer
-        else:
-            self.json_deserializer = json.loads
-
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for :class:`.JSON`."""
 
@@ -113,23 +116,25 @@ class JSON(sqltypes.TypeEngine):
                 _adapt_expression(self, op, other_comparator)
 
     def bind_processor(self, dialect):
+        json_serializer = dialect._json_serializer or json.dumps
         if util.py2k:
             encoding = dialect.encoding
             def process(value):
-                return self.json_serializer(value).encode(encoding)
+                return json_serializer(value).encode(encoding)
         else:
             def process(value):
-                return self.json_serializer(value)
+                return json_serializer(value)
         return process
 
     def result_processor(self, dialect, coltype):
+        json_deserializer = dialect._json_deserializer or json.loads
         if util.py2k:
             encoding = dialect.encoding
             def process(value):
-                return self.json_deserializer(value.decode(encoding))
+                return json_deserializer(value.decode(encoding))
         else:
             def process(value):
-                return self.json_deserializer(value)
+                return json_deserializer(value)
         return process
 
 
index ceb04b5801d60e882bb436ac3ed04761c9091245..f5da8a711f5dae2309a46226b7de52454afdfc55 100644 (file)
@@ -428,6 +428,11 @@ class PGDialect_psycopg2(PGDialect):
                                         array_oid=array_oid)
             fns.append(on_connect)
 
+        if self.dbapi and self._json_deserializer:
+            def on_connect(conn):
+                extras.register_default_json(conn, loads=self._json_deserializer)
+            fns.append(on_connect)
+
         if fns:
             def on_connect(conn):
                 for fn in fns:
index 062c708a5717ce49bc2203c0276c34b321f2fd4c..bcb3e1ebb15103952258556d0da541f0226e1a46 100644 (file)
@@ -998,9 +998,8 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_bind_serialize_default(self):
-        from sqlalchemy.engine import default
 
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
         eq_(
             proc(util.OrderedDict([("key1", "value1"), ("key2", "value2")])),
@@ -1008,9 +1007,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_bind_serialize_with_slashes_and_quotes(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
         eq_(
             proc({'\\"a': '\\"1'}),
@@ -1018,9 +1015,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_parse_error(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         assert_raises_message(
@@ -1033,9 +1028,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_result_deserialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         eq_(
@@ -1044,9 +1037,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_result_deserialize_with_slashes_and_quotes(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.hash.type._cached_result_processor(
                     dialect, None)
         eq_(
@@ -1693,9 +1684,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_bind_serialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.test_column.type._cached_bind_processor(dialect)
         eq_(
             proc({"A": [1, 2, 3, True, False]}),
@@ -1703,9 +1692,7 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_result_deserialize_default(self):
-        from sqlalchemy.engine import default
-
-        dialect = default.DefaultDialect()
+        dialect = postgresql.dialect()
         proc = self.test_table.c.test_column.type._cached_result_processor(
                     dialect, None)
         eq_(
@@ -1782,16 +1769,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
         )
         self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
 
-    def _non_native_engine(self):
+    def _non_native_engine(self, json_serializer=None, json_deserializer=None):
+        if json_serializer is not None or json_deserializer is not None:
+            options = {
+                "json_serializer": json_serializer,
+                "json_deserializer": json_deserializer
+            }
+        else:
+            options = {}
+
         if testing.against("postgresql+psycopg2"):
             from psycopg2.extras import register_default_json
-            engine = engines.testing_engine()
+            engine = engines.testing_engine(options=options)
             @event.listens_for(engine, "connect")
             def connect(dbapi_connection, connection_record):
                 engine.dialect._has_native_json = False
                 def pass_(value):
                     return value
                 register_default_json(dbapi_connection, loads=pass_)
+        elif options:
+            engine = engines.testing_engine(options=options)
         else:
             engine = testing.db
         engine.connect()
@@ -1811,6 +1808,56 @@ class JSONRoundTripTest(fixtures.TablesTest):
         engine = self._non_native_engine()
         self._test_insert(engine)
 
+
+    def _test_custom_serialize_deserialize(self, native):
+        import json
+        def loads(value):
+            value = json.loads(value)
+            value['x'] = value['x'] + '_loads'
+            return value
+
+        def dumps(value):
+            value = dict(value)
+            value['x'] = 'dumps_y'
+            return json.dumps(value)
+
+        if native:
+            engine = engines.testing_engine(options=dict(
+                            json_serializer=dumps,
+                            json_deserializer=loads
+                        ))
+        else:
+            engine = self._non_native_engine(
+                            json_serializer=dumps,
+                            json_deserializer=loads
+                        )
+
+        s = select([
+                cast(
+                    {
+                        "key": "value",
+                        "x": "q"
+                    },
+                    JSON
+                )
+            ])
+        eq_(
+            engine.scalar(s),
+            {
+                "key": "value",
+                "x": "dumps_y_loads"
+            },
+        )
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_custom_native(self):
+        self._test_custom_serialize_deserialize(True)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_custom_python(self):
+        self._test_custom_serialize_deserialize(False)
+
+
     @testing.only_on("postgresql+psycopg2")
     def test_criterion_native(self):
         engine = testing.db