hstore adjustments
author    Mike Bayer <mike_mp@zzzcomputing.com>    Sun, 18 Nov 2012 01:45:17 +0000 (20:45 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>    Sun, 18 Nov 2012 01:45:17 +0000 (20:45 -0500)
doc/build/changelog/changelog_08.rst
doc/build/orm/extensions/mutable.rst
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/hstore.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/ext/mutable.py
lib/sqlalchemy/testing/exclusions.py
test/dialect/test_postgresql.py
test/ext/test_mutable.py
test/requirements.py

diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst
index 6b4e49bc34aaeb9d248c6c21b9fd07521381cbb7..7ec712a76e1ac77e001ec584dd46802469fa4d6f 100644
@@ -6,6 +6,14 @@
 .. changelog::
     :version: 0.8.0b2
 
+    .. change::
+        :tags: postgresql, hstore
+        :tickets: 2606
+
+      :class:`.HSTORE` is now available in the Postgresql dialect.
+      Will also use psycopg2's extensions if available.  Courtesy
+      Audrius Kažukauskas.
+
     .. change::
         :tags: sybase, feature
         :tickets: 1753
diff --git a/doc/build/orm/extensions/mutable.rst b/doc/build/orm/extensions/mutable.rst
index b0428f951959989e28ad2a66579e106a3cd7e8a4..259055980df0652e9b487d6afaf83789c7abe3e0 100644
@@ -19,6 +19,9 @@ API Reference
     :show-inheritance:
     :members:
 
+.. autoclass:: MutableDict
+    :show-inheritance:
+    :members:
 
 
 
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 625ece6a131bb9d0ded36873793cdd8741bfda29..f1061c90b95d482f4c64950d025d5eff062bdd7c 100644
@@ -968,6 +968,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_BIGINT(self, type_):
         return "BIGINT"
 
+    def visit_HSTORE(self, type_):
+        return "HSTORE"
+
     def visit_datetime(self, type_):
         return self.visit_TIMESTAMP(type_)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py
index 4797031fa051dca1718d9a2376831f89c79ae1a5..ee5dec168e83ee59f02d814e6b08082af9e2fc55 100644
@@ -11,7 +11,6 @@ from ... import types as sqltypes
 from ...sql import functions as sqlfunc
 from ...sql.operators import custom_op
 from ...exc import SQLAlchemyError
-from ...ext.mutable import Mutable
 
 __all__ = ('HStoreSyntaxError', 'HSTORE', 'hstore')
 
@@ -114,41 +113,56 @@ def _serialize_hstore(val):
                      for k, v in val.iteritems())
 
 
-class MutationDict(Mutable, dict):
-    def __setitem__(self, key, value):
-        """Detect dictionary set events and emit change events."""
-        dict.__setitem__(self, key, value)
-        self.changed()
+class HSTORE(sqltypes.Concatenable, sqltypes.TypeEngine):
+    """Represent the Postgresql HSTORE type.
 
-    def __delitem__(self, key, value):
-        """Detect dictionary del events and emit change events."""
-        dict.__delitem__(self, key, value)
-        self.changed()
+    The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
 
-    @classmethod
-    def coerce(cls, key, value):
-        """Convert plain dictionary to MutationDict."""
-        if not isinstance(value, MutationDict):
-            if isinstance(value, dict):
-                return MutationDict(value)
-            return Mutable.coerce(key, value)
-        else:
-            return value
+        data_table = Table('data_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', HSTORE)
+        )
+
+        with engine.connect() as conn:
+            conn.execute(
+                data_table.insert(),
+                data = {"key1": "value1", "key2": "value2"}
+            )
+
+    :class:`.HSTORE` provides for a wide range of operations, including:
+
+    * :meth:`.HSTORE.comparator_factory.has_key`
+
+    * :meth:`.HSTORE.comparator_factory.has_all`
+
+    * :meth:`.HSTORE.comparator_factory.defined`
 
-    def __getstate__(self):
-        return dict(self)
+    For usage with the SQLAlchemy ORM, it may be desirable to combine
+    the usage of :class:`.HSTORE` with the :mod:`sqlalchemy.ext.mutable`
+    extension.  This extension will allow in-place changes to dictionary
+    values to be detected by the unit of work::
 
-    def __setstate__(self, state):
-        self.update(state)
+        from sqlalchemy.ext.mutable import MutableDict
 
+        class MyClass(Base):
+            __tablename__ = 'data_table'
 
-class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType):
-    """The column type for representing PostgreSQL's contrib/hstore type.  This
-    type is a miniature key-value store in a column.  It supports query
-    operators for all the usual operations on a map-like data structure.
+            id = Column(Integer, primary_key=True)
+            data = Column(MutableDict.as_mutable(HSTORE))
+
+        my_object = session.query(MyClass).one()
+
+        # in-place mutation, requires Mutable extension
+        # in order for the ORM to detect
+        my_object.data['some_key'] = 'some value'
+
+        session.commit()
 
     """
-    class comparator_factory(sqltypes.UserDefinedType.Comparator):
+
+    __visit_name__ = 'HSTORE'
+
+    class comparator_factory(sqltypes.TypeEngine.Comparator):
         def has_key(self, other):
             """Boolean expression.  Test for presence of a key.  Note that the
             key may be a SQLA expression.
@@ -237,6 +251,15 @@ class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType):
                     return op, sqltypes.Text
             return op, other_comparator.type
 
+    #@util.memoized_property
+    #@property
+    #def _expression_adaptations(self):
+    #    return {
+    #        operators.getitem: {
+    #           sqltypes.String: sqltypes.String
+    #        },
+    #    }
+
     def bind_processor(self, dialect):
         def process(value):
             if isinstance(value, dict):
@@ -245,9 +268,6 @@ class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType):
                 return value
         return process
 
-    def get_col_spec(self):
-        return 'HSTORE'
-
     def result_processor(self, dialect, coltype):
         def process(value):
             if value is not None:
@@ -256,8 +276,6 @@ class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType):
                 return value
         return process
 
-MutationDict.associate_with(HSTORE)
-
 
 class hstore(sqlfunc.GenericFunction):
     """Construct an hstore on the server side using the hstore function.
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index 05286ce203a5fcdc58041c66bd157b35f50e1f24..73f7123287083a0c08c8d600c445c04c749469fb 100644
@@ -131,9 +131,16 @@ The psycopg2 dialect will log Postgresql NOTICE messages via the
     import logging
     logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
 
+HSTORE type
+------------
 
-"""
+The psycopg2 dialect will make use of the
+``psycopg2.extensions.register_hstore()`` extension when using the HSTORE
+type.  This replaces SQLAlchemy's pure-Python HSTORE coercion which takes
+effect for other DBAPIs.
 
+"""
+from __future__ import absolute_import
 import re
 import logging
 
@@ -198,10 +205,16 @@ class _PGArray(ARRAY):
 
 class _PGHStore(HSTORE):
     def bind_processor(self, dialect):
-        return None
+        if dialect._has_native_hstore:
+            return None
+        else:
+            return super(_PGHStore, self).bind_processor(dialect)
 
     def result_processor(self, dialect, coltype):
-        return None
+        if dialect._has_native_hstore:
+            return None
+        else:
+            return super(_PGHStore, self).result_processor(dialect, coltype)
 
 # When we're handed literal SQL, ensure it's a SELECT-query. Since
 # 8.3, combining cursors and "FOR UPDATE" has been fine.
@@ -283,6 +296,8 @@ class PGDialect_psycopg2(PGDialect):
     preparer = PGIdentifierPreparer_psycopg2
     psycopg2_version = (0, 0)
 
+    _has_native_hstore = False
+
     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
@@ -295,10 +310,13 @@ class PGDialect_psycopg2(PGDialect):
     )
 
     def __init__(self, server_side_cursors=False, use_native_unicode=True,
-                        client_encoding=None, **kwargs):
+                        client_encoding=None,
+                        use_native_hstore=True,
+                        **kwargs):
         PGDialect.__init__(self, **kwargs)
         self.server_side_cursors = server_side_cursors
         self.use_native_unicode = use_native_unicode
+        self.use_native_hstore = use_native_hstore
         self.supports_unicode_binds = use_native_unicode
         self.client_encoding = client_encoding
         if self.dbapi and hasattr(self.dbapi, '__version__'):
@@ -309,21 +327,18 @@ class PGDialect_psycopg2(PGDialect):
                                             int(x)
                                             for x in m.group(1, 2, 3)
                                             if x is not None)
-        self._hstore_oids = None
+
 
     def initialize(self, connection):
         super(PGDialect_psycopg2, self).initialize(connection)
-
-        if self.psycopg2_version >= (2, 4):
-            extras = __import__('psycopg2.extras').extras
-            oids = extras.HstoreAdapter.get_oids(connection.connection)
-            if oids is not None and oids[0]:
-                self._hstore_oids = oids[0], oids[1]
+        self._has_native_hstore = self.use_native_hstore and \
+                        self._hstore_oids(connection.connection) \
+                            is not None
 
     @classmethod
     def dbapi(cls):
-        psycopg = __import__('psycopg2')
-        return psycopg
+        import psycopg2
+        return psycopg2
 
     @util.memoized_property
     def _isolation_lookup(self):
@@ -348,6 +363,8 @@ class PGDialect_psycopg2(PGDialect):
         connection.set_isolation_level(level)
 
     def on_connect(self):
+        from psycopg2 import extras, extensions
+
         fns = []
         if self.client_encoding is not None:
             def on_connect(conn):
@@ -360,17 +377,17 @@ class PGDialect_psycopg2(PGDialect):
             fns.append(on_connect)
 
         if self.dbapi and self.use_native_unicode:
-            extensions = __import__('psycopg2.extensions').extensions
             def on_connect(conn):
                 extensions.register_type(extensions.UNICODE, conn)
             fns.append(on_connect)
 
-        extras = __import__('psycopg2.extras').extras
-        def on_connect(conn):
-            if self._hstore_oids is not None:
-                oid, array_oid = self._hstore_oids
-                extras.register_hstore(conn, oid=oid, array_oid=array_oid)
-        fns.append(on_connect)
+        if self.dbapi and self.use_native_hstore:
+            def on_connect(conn):
+                hstore_oids = self._hstore_oids(conn)
+                if hstore_oids is not None:
+                    oid, array_oid = hstore_oids
+                    extras.register_hstore(conn, oid=oid, array_oid=array_oid)
+            fns.append(on_connect)
 
         if fns:
             def on_connect(conn):
@@ -380,6 +397,15 @@ class PGDialect_psycopg2(PGDialect):
         else:
             return None
 
+    @util.memoized_instancemethod
+    def _hstore_oids(self, conn):
+        if self.psycopg2_version >= (2, 4):
+            from psycopg2 import extras
+            oids = extras.HstoreAdapter.get_oids(conn)
+            if oids is not None and oids[0]:
+                return oids[0:2]
+        return None
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         if 'port' in opts:
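The new use_native_hstore dialect argument introduced above can be passed through create_engine(). A brief sketch, assuming a placeholder connection URL:

    from sqlalchemy import create_engine

    # by default the dialect calls psycopg2's register_hstore() when the
    # hstore extension is detected; use_native_hstore=False falls back to
    # the pure-Python bind/result processors defined on the HSTORE type
    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        use_native_hstore=False
    )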
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
index 2f5c68d7b10feb7ce39b8557299bf24179568e49..fcc493875f09d1b44bde3e2bbbee495c357cc520 100644
@@ -49,21 +49,21 @@ with any type whose target Python type may be mutable, including
 :class:`.PickleType`, :class:`.postgresql.ARRAY`, etc.
 
 When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
-tracks all parents which reference it.  Here we will replace the usage
-of plain Python dictionaries with a dict subclass that implements
-the :class:`.Mutable` mixin::
+tracks all parents which reference it.  Below, we illustrate a simplified
+version of the :class:`.MutableDict` dictionary object, which applies
+the :class:`.Mutable` mixin to a plain Python dictionary::
 
     import collections
     from sqlalchemy.ext.mutable import Mutable
 
-    class MutationDict(Mutable, dict):
+    class MutableDict(Mutable, dict):
         @classmethod
         def coerce(cls, key, value):
-            "Convert plain dictionaries to MutationDict."
+            "Convert plain dictionaries to MutableDict."
 
-            if not isinstance(value, MutationDict):
+            if not isinstance(value, MutableDict):
                 if isinstance(value, dict):
-                    return MutationDict(value)
+                    return MutableDict(value)
 
                 # this call will raise ValueError
                 return Mutable.coerce(key, value)
@@ -84,23 +84,23 @@ the :class:`.Mutable` mixin::
 
 The above dictionary class takes the approach of subclassing the Python
 built-in ``dict`` to produce a dict
-subclass which routes all mutation events through ``__setitem__``. There are
-many variants on this approach, such as subclassing ``UserDict.UserDict``,
-the newer ``collections.MutableMapping``,  etc. The part that's important to this
+subclass which routes all mutation events through ``__setitem__``.  There are
+variants on this approach, such as subclassing ``UserDict.UserDict`` or
+``collections.MutableMapping``; the part that's important to this
 example is that the :meth:`.Mutable.changed` method is called whenever an in-place change to the
 datastructure takes place.
 
 We also redefine the :meth:`.Mutable.coerce` method which will be used to
-convert any values that are not instances of ``MutationDict``, such
+convert any values that are not instances of ``MutableDict``, such
 as the plain dictionaries returned by the ``json`` module, into the
 appropriate type.  Defining this method is optional; we could just as well have created our
-``JSONEncodedDict`` such that it always returns an instance of ``MutationDict``,
-and additionally ensured that all calling code uses ``MutationDict``
+``JSONEncodedDict`` such that it always returns an instance of ``MutableDict``,
+and additionally ensured that all calling code uses ``MutableDict``
 explicitly.  When :meth:`.Mutable.coerce` is not overridden, any values
 applied to a parent object which are not instances of the mutable type
 will raise a ``ValueError``.
 
-Our new ``MutationDict`` type offers a class method
+Our new ``MutableDict`` type offers a class method
 :meth:`~.Mutable.as_mutable` which we can use within column metadata
 to associate with types. This method grabs the given type object or
 class and associates a listener that will detect all future mappings
@@ -111,7 +111,7 @@ attribute. Such as, with classical table metadata::
 
     my_data = Table('my_data', metadata,
         Column('id', Integer, primary_key=True),
-        Column('data', MutationDict.as_mutable(JSONEncodedDict))
+        Column('data', MutableDict.as_mutable(JSONEncodedDict))
     )
 
 Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
@@ -139,7 +139,7 @@ There's no difference in usage when using declarative::
     class MyDataClass(Base):
         __tablename__ = 'my_data'
         id = Column(Integer, primary_key=True)
-        data = Column(MutationDict.as_mutable(JSONEncodedDict))
+        data = Column(MutableDict.as_mutable(JSONEncodedDict))
 
 Any in-place changes to the ``MyDataClass.data`` member
 will flag the attribute as "dirty" on the parent object::
@@ -155,13 +155,13 @@ will flag the attribute as "dirty" on the parent object::
     >>> assert m1 in sess.dirty
     True
 
-The ``MutationDict`` can be associated with all future instances
+The ``MutableDict`` can be associated with all future instances
 of ``JSONEncodedDict`` in one step, using :meth:`~.Mutable.associate_with`.  This
 is similar to :meth:`~.Mutable.as_mutable` except it will intercept
-all occurrences of ``MutationDict`` in all mappings unconditionally, without
+all occurrences of ``MutableDict`` in all mappings unconditionally, without
 the need to declare it individually::
 
-    MutationDict.associate_with(JSONEncodedDict)
+    MutableDict.associate_with(JSONEncodedDict)
 
     class MyDataClass(Base):
         __tablename__ = 'my_data'
@@ -193,7 +193,7 @@ stream::
 With our dictionary example, we need to return the contents of the dict itself
 (and also restore them on __setstate__)::
 
-    class MutationDict(Mutable, dict):
+    class MutableDict(Mutable, dict):
         # ....
 
         def __getstate__(self):
@@ -330,6 +330,7 @@ from ..orm.attributes import flag_modified
 from .. import event, types
 from ..orm import mapper, object_mapper
 from ..util import memoized_property
+from .. import exc
 import weakref
 
 class MutableBase(object):
@@ -459,6 +460,9 @@ class Mutable(MutableBase):
 
         """
 
+        if not isinstance(sqltype, types.TypeEngine):
+            raise exc.ArgumentError("Type instance expected, got %s" % sqltype)
+
         def listen_for_type(mapper, class_):
             for prop in mapper.iterate_properties:
                 if hasattr(prop, 'columns'):
@@ -562,3 +566,37 @@ class MutableComposite(MutableBase):
 
         event.listen(mapper, 'mapper_configured', listen_for_type)
 
+
+
+class MutableDict(Mutable, dict):
+    """A dictionary type that implements :class:`.Mutable`.
+
+    .. versionadded:: 0.8
+
+    """
+
+    def __setitem__(self, key, value):
+        """Detect dictionary set events and emit change events."""
+        dict.__setitem__(self, key, value)
+        self.changed()
+
+    def __delitem__(self, key):
+        """Detect dictionary del events and emit change events."""
+        dict.__delitem__(self, key)
+        self.changed()
+
+    @classmethod
+    def coerce(cls, key, value):
+        """Convert plain dictionary to MutableDict."""
+        if not isinstance(value, MutableDict):
+            if isinstance(value, dict):
+                return MutableDict(value)
+            return Mutable.coerce(key, value)
+        else:
+            return value
+
+    def __getstate__(self):
+        return dict(self)
+
+    def __setstate__(self, state):
+        self.update(state)
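A short, self-contained sketch of the new MutableDict in ORM use, following the as_mutable() pattern from mutable.rst above; the model, the PickleType column, and the in-memory SQLite URL are illustrative assumptions, not part of this commit:

    from sqlalchemy import Column, Integer, PickleType, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.mutable import MutableDict
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class MyDataClass(Base):
        __tablename__ = 'my_data'
        id = Column(Integer, primary_key=True)
        # as_mutable() links MutableDict to this specific type instance
        data = Column(MutableDict.as_mutable(PickleType()))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)

    sess = Session(engine)
    m1 = MyDataClass(data={'a': 'b'})
    sess.add(m1)
    sess.commit()

    # an in-place change to the dict flags the parent object as dirty
    m1.data['a'] = 'c'
    assert m1 in sess.dirty
    sess.commit()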
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 31ad26a4eccce8df7c32549089a4354084289be6..3c70ec8d9a198acb5948db71ee8541212f8ae60b 100644
@@ -283,7 +283,7 @@ def closed():
 def future(fn, *args, **kw):
     return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
 
-def fails_on(db, reason):
+def fails_on(db, reason=None):
     return fails_if(SpecPredicate(db), reason)
 
 def fails_on_everything_except(*dbs):
@@ -293,16 +293,16 @@ def fails_on_everything_except(*dbs):
                     ])
             )
 
-def skip(db, reason):
+def skip(db, reason=None):
     return skip_if(SpecPredicate(db), reason)
 
-def only_on(dbs, reason):
+def only_on(dbs, reason=None):
     return only_if(
             OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
     )
 
 
-def exclude(db, op, spec, reason):
+def exclude(db, op, spec, reason=None):
     return skip_if(SpecPredicate(db, op, spec), reason)
 
 
diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py
index 33753b48f65d112639345ca681202ae5eed1d233..25b6bcb54a4292dfb786098ec3cf5d337e5de2a5 100644
@@ -2745,6 +2745,61 @@ class HStoreTest(fixtures.TestBase):
             ) % expected
         )
 
+    def test_bind_serialize_default(self):
+        from sqlalchemy.engine import default
+
+        dialect = default.DefaultDialect()
+        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
+        eq_(
+            proc({"key1": "value1", "key2": "value2"}),
+            '"key2"=>"value2", "key1"=>"value1"'
+        )
+
+    def test_result_deserialize_default(self):
+        from sqlalchemy.engine import default
+
+        dialect = default.DefaultDialect()
+        proc = self.test_table.c.hash.type._cached_result_processor(
+                    dialect, None)
+        eq_(
+            proc('"key2"=>"value2", "key1"=>"value1"'),
+            {"key1": "value1", "key2": "value2"}
+        )
+
+    def test_bind_serialize_psycopg2(self):
+        from sqlalchemy.dialects.postgresql import psycopg2
+
+        dialect = psycopg2.PGDialect_psycopg2()
+        dialect._has_native_hstore = True
+        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
+        is_(proc, None)
+
+        dialect = psycopg2.PGDialect_psycopg2()
+        dialect._has_native_hstore = False
+        proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
+        eq_(
+            proc({"key1": "value1", "key2": "value2"}),
+            '"key2"=>"value2", "key1"=>"value1"'
+        )
+
+    def test_result_deserialize_psycopg2(self):
+        from sqlalchemy.dialects.postgresql import psycopg2
+
+        dialect = psycopg2.PGDialect_psycopg2()
+        dialect._has_native_hstore = True
+        proc = self.test_table.c.hash.type._cached_result_processor(
+                    dialect, None)
+        is_(proc, None)
+
+        dialect = psycopg2.PGDialect_psycopg2()
+        dialect._has_native_hstore = False
+        proc = self.test_table.c.hash.type._cached_result_processor(
+                    dialect, None)
+        eq_(
+            proc('"key2"=>"value2", "key1"=>"value1"'),
+            {"key1": "value1", "key2": "value2"}
+        )
+
     def test_where_has_key(self):
         self._test_where(
             self.hashcol.has_key('foo'),
@@ -2897,3 +2952,76 @@ class HStoreTest(fixtures.TestBase):
             "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1",
             True
         )
+
+class HStoreRoundTripTest(fixtures.TablesTest):
+    #__only_on__ = 'postgresql'
+    __requires__ = 'hstore',
+    __dialect__ = postgresql.dialect()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('data_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30), nullable=False),
+            Column('data', HSTORE)
+        )
+
+    def _fixture_data(self, engine):
+        data_table = self.tables.data_table
+        engine.execute(
+                data_table.insert(),
+                {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}},
+                {'name': 'r2', 'data': {"k1": "r2v1", "k2": "r2v2"}},
+                {'name': 'r3', 'data': {"k1": "r3v1", "k2": "r3v2"}},
+                {'name': 'r4', 'data': {"k1": "r4v1", "k2": "r4v2"}},
+                {'name': 'r5', 'data': {"k1": "r5v1", "k2": "r5v2"}},
+        )
+
+    def _assert_data(self, compare):
+        data = testing.db.execute(
+            select([self.tables.data_table.c.data]).
+                order_by(self.tables.data_table.c.name)
+        ).fetchall()
+        eq_([d for d, in data], compare)
+
+    def _test_insert(self, engine):
+        engine.execute(
+            self.tables.data_table.insert(),
+            {'name': 'r1', 'data': {"k1": "r1v1", "k2": "r1v2"}}
+        )
+        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}])
+
+    def _non_native_engine(self):
+        if testing.against("postgresql+psycopg2"):
+            engine = engines.testing_engine(options=dict(use_native_hstore=False))
+        else:
+            engine = testing.db
+        engine.connect()
+        return engine
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_insert_native(self):
+        engine = testing.db
+        self._test_insert(engine)
+
+    def test_insert_python(self):
+        engine = self._non_native_engine()
+        self._test_insert(engine)
+
+    @testing.only_on("postgresql+psycopg2")
+    def test_criterion_native(self):
+        engine = testing.db
+        self._fixture_data(engine)
+        self._test_criterion(engine)
+
+    def test_criterion_python(self):
+        engine = self._non_native_engine()
+        self._fixture_data(engine)
+        self._test_criterion(engine)
+
+    def _test_criterion(self, engine):
+        data_table = self.tables.data_table
+        result = engine.execute(
+            select([data_table.c.data]).where(data_table.c.data['k1'] == 'r3v1')
+        ).first()
+        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py
index f56ce40378d324c1fde2eb2fe9ceabbf5f873ae8..916ff9d4b175cb463e115a68383b7d9f202232b1 100644
@@ -28,40 +28,14 @@ class FooWithEq(object):
     def __eq__(self, other):
         return self.id == other.id
 
+from sqlalchemy.ext.mutable import MutableDict
+
 class _MutableDictTestBase(object):
     run_define_tables = 'each'
 
     @classmethod
     def _type_fixture(cls):
-        from sqlalchemy.ext.mutable import Mutable
-
-        # needed for pickle support
-        global MutationDict
-
-        class MutationDict(Mutable, dict):
-            @classmethod
-            def coerce(cls, key, value):
-                if not isinstance(value, MutationDict):
-                    if isinstance(value, dict):
-                        return MutationDict(value)
-                    return Mutable.coerce(key, value)
-                else:
-                    return value
-
-            def __getstate__(self):
-                return dict(self)
-
-            def __setstate__(self, state):
-                self.update(state)
-
-            def __setitem__(self, key, value):
-                dict.__setitem__(self, key, value)
-                self.changed()
-
-            def __delitem__(self, key):
-                dict.__delitem__(self, key)
-                self.changed()
-        return MutationDict
+        return MutableDict
 
     def setup_mappers(cls):
         foo = cls.tables.foo
@@ -152,9 +126,9 @@ class _MutableDictTestBase(object):
 class MutableWithScalarPickleTest(_MutableDictTestBase, fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
-        MutationDict = cls._type_fixture()
+        MutableDict = cls._type_fixture()
 
-        mutable_pickle = MutationDict.as_mutable(PickleType)
+        mutable_pickle = MutableDict.as_mutable(PickleType)
         Table('foo', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('skip', mutable_pickle),
@@ -188,11 +162,11 @@ class MutableWithScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest):
                     value = json.loads(value)
                 return value
 
-        MutationDict = cls._type_fixture()
+        MutableDict = cls._type_fixture()
 
         Table('foo', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
-            Column('data', MutationDict.as_mutable(JSONEncodedDict)),
+            Column('data', MutableDict.as_mutable(JSONEncodedDict)),
             Column('non_mutable_data', JSONEncodedDict),
             Column('unrelated_data', String(50))
         )
@@ -203,7 +177,7 @@ class MutableWithScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest):
 class MutableAssocWithAttrInheritTest(_MutableDictTestBase, fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
-        MutationDict = cls._type_fixture()
+        MutableDict = cls._type_fixture()
 
         Table('foo', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
@@ -222,7 +196,7 @@ class MutableAssocWithAttrInheritTest(_MutableDictTestBase, fixtures.MappedTest)
 
         mapper(Foo, foo)
         mapper(SubFoo, subfoo, inherits=Foo)
-        MutationDict.associate_with_attribute(Foo.data)
+        MutableDict.associate_with_attribute(Foo.data)
 
     def test_in_place_mutation(self):
         sess = Session()
@@ -249,8 +223,8 @@ class MutableAssocWithAttrInheritTest(_MutableDictTestBase, fixtures.MappedTest)
 class MutableAssociationScalarPickleTest(_MutableDictTestBase, fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
-        MutationDict = cls._type_fixture()
-        MutationDict.associate_with(PickleType)
+        MutableDict = cls._type_fixture()
+        MutableDict.associate_with(PickleType)
 
         Table('foo', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
@@ -281,8 +255,8 @@ class MutableAssociationScalarJSONTest(_MutableDictTestBase, fixtures.MappedTest
                     value = json.loads(value)
                 return value
 
-        MutationDict = cls._type_fixture()
-        MutationDict.associate_with(JSONEncodedDict)
+        MutableDict = cls._type_fixture()
+        MutableDict.associate_with(JSONEncodedDict)
 
         Table('foo', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
diff --git a/test/requirements.py b/test/requirements.py
index 1cc7b058d8ed736f3ecd30e9c31ea0b4c27c1985..202cd8466744ccb944e04b22016a26865aa54cc3 100644
@@ -480,6 +480,19 @@ class DefaultRequirements(SuiteRequirements):
         """
         return self.cpython
 
+    @property
+    def hstore(self):
+        def check_hstore():
+            if not against("postgresql"):
+                return False
+            try:
+                self.db.execute("SELECT 'a=>1,a=>2'::hstore;")
+                return True
+            except:
+                return False
+
+        return only_if(check_hstore)
+
     @property
     def sqlite(self):
         return skip_if(lambda: not self._has_sqlite())