]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add PostgreSQL HStore type support
authorAudrius Kažukauskas <audrius@neutrino.lt>
Tue, 13 Nov 2012 14:43:41 +0000 (16:43 +0200)
committerAudrius Kažukauskas <audrius@neutrino.lt>
Tue, 13 Nov 2012 14:43:41 +0000 (16:43 +0200)
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/hstore.py [new file with mode: 0644]
test/dialect/test_postgresql.py

index 3c273bd56e3366c6764fc557e774c1c317f8027d..2a1a07cbd7955e77ebca670330e8931035e7b370 100644 (file)
@@ -9,12 +9,15 @@ from . import base, psycopg2, pg8000, pypostgresql, zxjdbc
 base.dialect = psycopg2.dialect
 
 from .base import \
-    INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, INET, \
-    CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME,\
+    INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
+    INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
     DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array
+from .hstore import HSTORE, hstore, HStoreSyntaxError
 
 __all__ = (
-'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', 'FLOAT', 'REAL', 'INET',
-'CIDR', 'UUID', 'BIT', 'MACADDR', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME',
-'DATE', 'BYTEA', 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array'
+    'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC',
+    'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR',
+    'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', 'BOOLEAN',
+    'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', 'hstore',
+    'HStoreSyntaxError'
 )
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py
new file mode 100644 (file)
index 0000000..4797031
--- /dev/null
@@ -0,0 +1,306 @@
+# postgresql/hstore.py
+# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .base import ARRAY
+from ... import types as sqltypes
+from ...sql import functions as sqlfunc
+from ...sql.operators import custom_op
+from ...exc import SQLAlchemyError
+from ...ext.mutable import Mutable
+
+__all__ = ('HStoreSyntaxError', 'HSTORE', 'hstore')
+
+# My best guess at the parsing rules of hstore literals, since no formal
+# grammar is given.  This is mostly reverse engineered from PG's input parser
+# behavior.
+HSTORE_PAIR_RE = re.compile(r"""
+(
+  "(?P<key> (\\ . | [^"])* )"       # Quoted key
+)
+[ ]* => [ ]*    # Pair operator, optional adjoining whitespace
+(
+    (?P<value_null> NULL )          # NULL value
+  | "(?P<value> (\\ . | [^"])* )"   # Quoted value
+)
+""", re.VERBOSE)
+
+HSTORE_DELIMITER_RE = re.compile(r"""
+[ ]* , [ ]*
+""", re.VERBOSE)
+
+
+class HStoreSyntaxError(SQLAlchemyError):
+    """Indicates an error unmarshalling an hstore value."""
+
+    def __init__(self, hstore_str, pos):
+        self.hstore_str = hstore_str
+        self.pos = pos
+
+        ctx = 20
+        hslen = len(hstore_str)
+
+        parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)]
+        residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)]
+
+        if len(parsed_tail) > ctx:
+            parsed_tail = '[...]' + parsed_tail[1:]
+        if len(residual) > ctx:
+            residual = residual[:-1] + '[...]'
+
+        super(HStoreSyntaxError, self).__init__(
+            "After %r, could not parse residual at position %d: %r" %
+            (parsed_tail, pos, residual)
+        )
+
+
+def _parse_hstore(hstore_str):
+    """Parse an hstore from it's literal string representation.
+
+    Attempts to approximate PG's hstore input parsing rules as closely as
+    possible. Although currently this is not strictly necessary, since the
+    current implementation of hstore's output syntax is stricter than what it
+    accepts as input, the documentation makes no guarantees that will always
+    be the case.
+
+    Throws HStoreSyntaxError if parsing fails.
+
+    """
+    result = {}
+    pos = 0
+    pair_match = HSTORE_PAIR_RE.match(hstore_str)
+
+    while pair_match is not None:
+        key = pair_match.group('key')
+        if pair_match.group('value_null'):
+            value = None
+        else:
+            value = pair_match.group('value').replace(r'\"', '"')
+        result[key] = value
+
+        pos += pair_match.end()
+
+        delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
+        if delim_match is not None:
+            pos += delim_match.end()
+
+        pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
+
+    if pos != len(hstore_str):
+        raise HStoreSyntaxError(hstore_str, pos)
+
+    return result
+
+
+def _serialize_hstore(val):
+    """Serialize a dictionary into an hstore literal.  Keys and values must
+    both be strings (except None for values).
+
+    """
+    def esc(s, position):
+        if position == 'value' and s is None:
+            return 'NULL'
+        elif isinstance(s, basestring):
+            return '"%s"' % s.replace('"', r'\"')
+        else:
+            raise ValueError("%r in %s position is not a string." %
+                             (s, position))
+
+    return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value'))
+                     for k, v in val.iteritems())
+
+
+class MutationDict(Mutable, dict):
+    def __setitem__(self, key, value):
+        """Detect dictionary set events and emit change events."""
+        dict.__setitem__(self, key, value)
+        self.changed()
+
+    def __delitem__(self, key, value):
+        """Detect dictionary del events and emit change events."""
+        dict.__delitem__(self, key, value)
+        self.changed()
+
+    @classmethod
+    def coerce(cls, key, value):
+        """Convert plain dictionary to MutationDict."""
+        if not isinstance(value, MutationDict):
+            if isinstance(value, dict):
+                return MutationDict(value)
+            return Mutable.coerce(key, value)
+        else:
+            return value
+
+    def __getstate__(self):
+        return dict(self)
+
+    def __setstate__(self, state):
+        self.update(state)
+
+
+class HSTORE(sqltypes.Concatenable, sqltypes.UserDefinedType):
+    """The column type for representing PostgreSQL's contrib/hstore type.  This
+    type is a miniature key-value store in a column.  It supports query
+    operators for all the usual operations on a map-like data structure.
+
+    """
+    class comparator_factory(sqltypes.UserDefinedType.Comparator):
+        def has_key(self, other):
+            """Boolean expression.  Test for presence of a key.  Note that the
+            key may be a SQLA expression.
+            """
+            return self.expr.op('?')(other)
+
+        def has_all(self, other):
+            """Boolean expression.  Test for presence of all keys in the PG
+            array.
+            """
+            return self.expr.op('?&')(other)
+
+        def has_any(self, other):
+            """Boolean expression.  Test for presence of any key in the PG
+            array.
+            """
+            return self.expr.op('?|')(other)
+
+        def defined(self, key):
+            """Boolean expression.  Test for presence of a non-NULL value for
+            the key.  Note that the key may be a SQLA expression.
+            """
+            return _HStoreDefinedFunction(self.expr, key)
+
+        def contains(self, other, **kwargs):
+            """Boolean expression.  Test if keys are a superset of the keys of
+            the argument hstore expression.
+            """
+            return self.expr.op('@>')(other)
+
+        def contained_by(self, other):
+            """Boolean expression.  Test if keys are a proper subset of the
+            keys of the argument hstore expression.
+            """
+            return self.expr.op('<@')(other)
+
+        def __getitem__(self, other):
+            """Text expression.  Get the value at a given key.  Note that the
+            key may be a SQLA expression.
+            """
+            return self.expr.op('->', precedence=5)(other)
+
+        def __add__(self, other):
+            """HStore expression.  Merge the left and right hstore expressions,
+            with duplicate keys taking the value from the right expression.
+            """
+            return self.expr.concat(other)
+
+        def delete(self, key):
+            """HStore expression.  Returns the contents of this hstore with the
+            given key deleted.  Note that the key may be a SQLA expression.
+            """
+            if isinstance(key, dict):
+                key = _serialize_hstore(key)
+            return _HStoreDeleteFunction(self.expr, key)
+
+        def slice(self, array):
+            """HStore expression.  Returns a subset of an hstore defined by
+            array of keys.
+            """
+            return _HStoreSliceFunction(self.expr, array)
+
+        def keys(self):
+            """Text array expression.  Returns array of keys."""
+            return _HStoreKeysFunction(self.expr)
+
+        def vals(self):
+            """Text array expression.  Returns array of values."""
+            return _HStoreValsFunction(self.expr)
+
+        def array(self):
+            """Text array expression.  Returns array of alternating keys and
+            values.
+            """
+            return _HStoreArrayFunction(self.expr)
+
+        def matrix(self):
+            """Text array expression.  Returns array of [key, value] pairs."""
+            return _HStoreMatrixFunction(self.expr)
+
+        def _adapt_expression(self, op, other_comparator):
+            if isinstance(op, custom_op):
+                if op.opstring in ['?', '?&', '?|', '@>', '<@']:
+                    return op, sqltypes.Boolean
+                elif op.opstring == '->':
+                    return op, sqltypes.Text
+            return op, other_comparator.type
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, dict):
+                return _serialize_hstore(value)
+            else:
+                return value
+        return process
+
+    def get_col_spec(self):
+        return 'HSTORE'
+
+    def result_processor(self, dialect, coltype):
+        def process(value):
+            if value is not None:
+                return _parse_hstore(value)
+            else:
+                return value
+        return process
+
+MutationDict.associate_with(HSTORE)
+
+
+class hstore(sqlfunc.GenericFunction):
+    """Construct an hstore on the server side using the hstore function.
+
+    The single argument or a pair of arguments are evaluated as SQLAlchemy
+    expressions, so both may contain columns, function calls, or any other
+    valid SQL expressions which evaluate to text or array.
+
+    """
+    type = HSTORE
+    name = 'hstore'
+
+
+class _HStoreDefinedFunction(sqlfunc.GenericFunction):
+    type = sqltypes.Boolean
+    name = 'defined'
+
+
+class _HStoreDeleteFunction(sqlfunc.GenericFunction):
+    type = HSTORE
+    name = 'delete'
+
+
+class _HStoreSliceFunction(sqlfunc.GenericFunction):
+    type = HSTORE
+    name = 'slice'
+
+
+class _HStoreKeysFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = 'akeys'
+
+
+class _HStoreValsFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = 'avals'
+
+
+class _HStoreArrayFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = 'hstore_to_array'
+
+
+class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+    type = ARRAY(sqltypes.Text)
+    name = 'hstore_to_matrix'
index 3be005f36a855c753e9dc2e8de8b09929529efcd..33753b48f65d112639345ca681202ae5eed1d233 100644 (file)
@@ -13,14 +13,16 @@ from sqlalchemy import Table, Column, select, MetaData, text, Integer, \
             PrimaryKeyConstraint, DateTime, tuple_, Float, BigInteger, \
             func, literal_column, literal, bindparam, cast, extract, \
             SmallInteger, Enum, REAL, update, insert, Index, delete, \
-            and_, Date, TypeDecorator, Time, Unicode, Interval, or_
+            and_, Date, TypeDecorator, Time, Unicode, Interval, or_, Text
 from sqlalchemy.orm import Session, mapper, aliased
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import HSTORE, hstore
 from sqlalchemy.util.compat import decimal
 from sqlalchemy.testing.util import round_decimal
 from sqlalchemy.sql import table, column
 import logging
+import re
 
 class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
 
@@ -2707,3 +2709,191 @@ class TupleTest(fixtures.TestBase):
                 ).scalar(),
                 exp
             )
+
+
+class HStoreTest(fixtures.TestBase):
+    def _assert_sql(self, construct, expected):
+        dialect = postgresql.dialect()
+        compiled = str(construct.compile(dialect=dialect))
+        compiled = re.sub(r'\s+', ' ', compiled)
+        expected = re.sub(r'\s+', ' ', expected)
+        eq_(compiled, expected)
+
+    def setup(self):
+        metadata = MetaData()
+        self.test_table = Table('test_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('hash', HSTORE)
+        )
+        self.hashcol = self.test_table.c.hash
+
+    def _test_where(self, whereclause, expected):
+        stmt = select([self.test_table]).where(whereclause)
+        self._assert_sql(
+            stmt,
+            "SELECT test_table.id, test_table.hash FROM test_table "
+            "WHERE %s" % expected
+        )
+
+    def _test_cols(self, colclause, expected, from_=True):
+        stmt = select([colclause])
+        self._assert_sql(
+            stmt,
+            (
+                "SELECT %s" +
+                (" FROM test_table" if from_ else "")
+            ) % expected
+        )
+
+    def test_where_has_key(self):
+        self._test_where(
+            self.hashcol.has_key('foo'),
+            "test_table.hash ? %(hash_1)s"
+        )
+
+    def test_where_has_all(self):
+        self._test_where(
+            self.hashcol.has_all(postgresql.array(['1', '2'])),
+            "test_table.hash ?& ARRAY[%(param_1)s, %(param_2)s]"
+        )
+
+    def test_where_has_any(self):
+        self._test_where(
+            self.hashcol.has_any(postgresql.array(['1', '2'])),
+            "test_table.hash ?| ARRAY[%(param_1)s, %(param_2)s]"
+        )
+
+    def test_where_defined(self):
+        self._test_where(
+            self.hashcol.defined('foo'),
+            "defined(test_table.hash, %(param_1)s)"
+        )
+
+    def test_where_contains(self):
+        self._test_where(
+            self.hashcol.contains({'foo': '1'}),
+            "test_table.hash @> %(hash_1)s"
+        )
+
+    def test_where_contained_by(self):
+        self._test_where(
+            self.hashcol.contained_by({'foo': '1', 'bar': None}),
+            "test_table.hash <@ %(hash_1)s"
+        )
+
+    def test_where_getitem(self):
+        self._test_where(
+            self.hashcol['bar'] == None,
+            "(test_table.hash -> %(hash_1)s) IS NULL"
+        )
+
+    def test_cols_get(self):
+        self._test_cols(
+            self.hashcol['foo'],
+            "test_table.hash -> %(hash_1)s AS anon_1",
+            True
+        )
+
+    def test_cols_delete_single_key(self):
+        self._test_cols(
+            self.hashcol.delete('foo'),
+            "delete(test_table.hash, %(param_1)s) AS delete_1",
+            True
+        )
+
+    def test_cols_delete_array_of_keys(self):
+        self._test_cols(
+            self.hashcol.delete(postgresql.array(['foo', 'bar'])),
+            ("delete(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
+             "AS delete_1"),
+            True
+        )
+
+    def test_cols_delete_matching_pairs(self):
+        self._test_cols(
+            self.hashcol.delete(hstore('1', '2')),
+            ("delete(test_table.hash, hstore(%(param_1)s, %(param_2)s)) "
+             "AS delete_1"),
+            True
+        )
+
+    def test_cols_slice(self):
+        self._test_cols(
+            self.hashcol.slice(postgresql.array(['1', '2'])),
+            ("slice(test_table.hash, ARRAY[%(param_1)s, %(param_2)s]) "
+             "AS slice_1"),
+            True
+        )
+
+    def test_cols_hstore_pair_text(self):
+        self._test_cols(
+            hstore('foo', '3')['foo'],
+            "hstore(%(param_1)s, %(param_2)s) -> %(hstore_1)s AS anon_1",
+            False
+        )
+
+    def test_cols_hstore_pair_array(self):
+        self._test_cols(
+            hstore(postgresql.array(['1', '2']),
+                   postgresql.array(['3', None]))['1'],
+            ("hstore(ARRAY[%(param_1)s, %(param_2)s], "
+             "ARRAY[%(param_3)s, NULL]) -> %(hstore_1)s AS anon_1"),
+            False
+        )
+
+    def test_cols_hstore_single_array(self):
+        self._test_cols(
+            hstore(postgresql.array(['1', '2', '3', None]))['3'],
+            ("hstore(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, NULL]) "
+             "-> %(hstore_1)s AS anon_1"),
+            False
+        )
+
+    def test_cols_concat(self):
+        self._test_cols(
+            self.hashcol.concat(hstore(cast(self.test_table.c.id, Text), '3')),
+            ("test_table.hash || hstore(CAST(test_table.id AS TEXT), "
+             "%(param_1)s) AS anon_1"),
+            True
+        )
+
+    def test_cols_concat_op(self):
+        self._test_cols(
+            self.hashcol + self.hashcol,
+            "test_table.hash || test_table.hash AS anon_1",
+            True
+        )
+
+    def test_cols_concat_get(self):
+        self._test_cols(
+            (self.hashcol + self.hashcol)['foo'],
+            "test_table.hash || test_table.hash -> %(param_1)s AS anon_1"
+        )
+
+    def test_cols_keys(self):
+        self._test_cols(
+            self.hashcol.keys(),
+            "akeys(test_table.hash) AS akeys_1",
+            True
+        )
+
+    def test_cols_vals(self):
+        self._test_cols(
+            self.hashcol.vals(),
+            "avals(test_table.hash) AS avals_1",
+            True
+        )
+
+    def test_cols_array(self):
+        self._test_cols(
+            self.hashcol.array(),
+            "hstore_to_array(test_table.hash) AS hstore_to_array_1",
+            True
+        )
+
+    def test_cols_matrix(self):
+        self._test_cols(
+            self.hashcol.matrix(),
+            "hstore_to_matrix(test_table.hash) AS hstore_to_matrix_1",
+            True
+        )