]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support functions "as binary comparison"
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 9 Jul 2018 19:47:14 +0000 (15:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2018 01:48:39 +0000 (21:48 -0400)
Added new feature :meth:`.FunctionElement.as_comparison` which allows a SQL
function to act as a binary comparison operation that can work within the
ORM.

Change-Id: I07018e2065d09775c0406cabdd35fc38cc0da699
Fixes: #3831
doc/build/changelog/migration_13.rst
doc/build/changelog/unreleased_13/3831.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
test/orm/test_relationships.py
test/sql/test_functions.py

index 71ef4d0113348bf61a66ac5305e882363a97ecf2..6201435cc5bfa2d5b54f49b49256f24e7a506848 100644 (file)
@@ -90,6 +90,66 @@ and can't easily be generalized for more complex queries.
 New Features and Improvements - Core
 ====================================
 
+.. _change_3831:
+
+Binary comparison interpretation for SQL functions
+--------------------------------------------------
+
+This enhancement is implemented at the Core level, however is applicable
+primarily to the ORM.
+
+A SQL function that compares two elements can now be used as a "comparison"
+object, suitable for usage in an ORM :func:`.relationship`, by first
+creating the function as usual using the :data:`.func` factory, then
+when the function is complete calling upon the :meth:`.FunctionElement.as_comparison`
+modifier to produce a :class:`.BinaryExpression` that has a "left" and a "right"
+side::
+
+    class Venue(Base):
+        __tablename__ = 'venue'
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        descendants = relationship(
+            "Venue",
+            primaryjoin=func.instr(
+                remote(foreign(name)), name + "/"
+            ).as_comparison(1, 2) == 1,
+            viewonly=True,
+            order_by=name
+        )
+
+Above, the :paramref:`.relationship.primaryjoin` of the "descendants" relationship
+will produce a "left" and a "right" expression based on the first and second
+arguments passed to ``instr()``.   This allows features like the ORM
+lazyload to produce SQL like::
+
+    SELECT venue.id AS venue_id, venue.name AS venue_name
+    FROM venue
+    WHERE instr(venue.name, (? || ?)) = ? ORDER BY venue.name
+    ('parent1', '/', 1)
+
+and a joinedload, such as::
+
+    v1 = s.query(Venue).filter_by(name="parent1").options(
+        joinedload(Venue.descendants)).one()
+
+to work as::
+
+    SELECT venue.id AS venue_id, venue.name AS venue_name,
+      venue_1.id AS venue_1_id, venue_1.name AS venue_1_name
+    FROM venue LEFT OUTER JOIN venue AS venue_1
+      ON instr(venue_1.name, (venue.name || ?)) = ?
+    WHERE venue.name = ? ORDER BY venue_1.name
+    ('/', 1, 'parent1')
+
+This feature is expected to help with situations such as making use of
+geometric functions in relationship join conditions, or any case where
+the ON clause of the SQL join is expressed in terms of a SQL function.
+
+:ticket:`3831`
+
+
 Key Behavioral Changes - Core
 =============================
 
diff --git a/doc/build/changelog/unreleased_13/3831.rst b/doc/build/changelog/unreleased_13/3831.rst
new file mode 100644 (file)
index 0000000..8df8f5c
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: feature, sql
+    :tickets: 3831
+
+    Added new feature :meth:`.FunctionElement.as_comparison` which allows a SQL
+    function to act as a binary comparison operation that can work within the
+    ORM.
+
+    .. seealso::
+
+        :ref:`change_3831`
index 75827b34e1ba5c2a25d8faecca3af8db5a45b778..ae1dd2c7cb05cd178578f5dddb1fc4d377562a37 100644 (file)
@@ -1079,6 +1079,9 @@ class SQLCompiler(Compiled):
             else:
                 return self._generate_generic_binary(binary, opstring, **kw)
 
+    def visit_function_as_comparison_op_binary(self, element, operator, **kw):
+        return self.process(element.sql_function, **kw)
+
     def visit_mod_binary(self, binary, operator, **kw):
         if self.preparer._double_percents:
             return self.process(binary.left, **kw) + " %% " + \
index 78aeb3a0171ab96e5c2f4152e34edc35268b7e49..27d030d4ff4dc3fc6fee1cf4f71855ede2d593a2 100644 (file)
@@ -12,7 +12,8 @@ from . import sqltypes, schema
 from .base import Executable, ColumnCollection
 from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
     literal_column, _type_from_args, ColumnElement, _clone,\
-    Over, BindParameter, FunctionFilter, Grouping, WithinGroup
+    Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \
+    BinaryExpression
 from .selectable import FromClause, Select, Alias
 from . import util as sqlutil
 from . import operators
@@ -166,6 +167,73 @@ class FunctionElement(Executable, ColumnElement, FromClause):
             return self
         return FunctionFilter(self, *criterion)
 
+    def as_comparison(self, left_index, right_index):
+        """Interpret this expression as a boolean comparison between two values.
+
+        A hypothetical SQL function "is_equal()" which compares to values
+        for equality would be written in the Core expression language as::
+
+            expr = func.is_equal("a", "b")
+
+        If "is_equal()" above is comparing "a" and "b" for equality, the
+        :meth:`.FunctionElement.as_comparison` method would be invoked as::
+
+            expr = func.is_equal("a", "b").as_comparison(1, 2)
+
+        Where above, the integer value "1" refers to the first argument of the
+        "is_equal()" function and the integer value "2" refers to the second.
+
+        This would create a :class:`.BinaryExpression` that is equivalent to::
+
+            BinaryExpression("a", "b", operator=op.eq)
+
+        However, at the SQL level it would still render as
+        "is_equal('a', 'b')".
+
+        The ORM, when it loads a related object or collection, needs to be able
+        to manipulate the "left" and "right" sides of the ON clause of a JOIN
+        expression. The purpose of this method is to provide a SQL function
+        construct that can also supply this information to the ORM, when used
+        with the :paramref:`.relationship.primaryjoin` parameter.  The return
+        value is a containment object called :class:`.FunctionAsBinary`.
+
+        An ORM example is as follows::
+
+            class Venue(Base):
+                __tablename__ = 'venue'
+                id = Column(Integer, primary_key=True)
+                name = Column(String)
+
+                descendants = relationship(
+                    "Venue",
+                    primaryjoin=func.instr(
+                        remote(foreign(name)), name + "/"
+                    ).as_comparison(1, 2) == 1,
+                    viewonly=True,
+                    order_by=name
+                )
+
+        Above, the "Venue" class can load descendant "Venue" objects by
+        determining if the name of the parent Venue is contained within the
+        start of the hypothetical descendant value's name, e.g. "parent1" would
+        match up to "parent1/child1", but not to "parent2/child1".
+
+        Possible use cases include the "materialized path" example given above,
+        as well as making use of special SQL functions such as geometric
+        functions to create join conditions.
+
+        :param left_index: the integer 1-based index of the function argument
+         that serves as the "left" side of the expression.
+        :param right_index: the integer 1-based index of the function argument
+         that serves as the "right" side of the expression.
+
+        .. versionadded:: 1.3
+
+        """
+        return FunctionAsBinary(
+            self, left_index, right_index
+        )
+
     @property
     def _from_objects(self):
         return self.clauses._from_objects
@@ -281,6 +349,41 @@ class FunctionElement(Executable, ColumnElement, FromClause):
             return super(FunctionElement, self).self_group(against=against)
 
 
+class FunctionAsBinary(BinaryExpression):
+
+    def __init__(self, fn, left_index, right_index):
+        left = fn.clauses.clauses[left_index - 1]
+        right = fn.clauses.clauses[right_index - 1]
+        self.sql_function = fn
+        self.left_index = left_index
+        self.right_index = right_index
+
+        super(FunctionAsBinary, self).__init__(
+            left, right, operators.function_as_comparison_op,
+            type_=sqltypes.BOOLEANTYPE)
+
+    @property
+    def left(self):
+        return self.sql_function.clauses.clauses[self.left_index - 1]
+
+    @left.setter
+    def left(self, value):
+        self.sql_function.clauses.clauses[self.left_index - 1] = value
+
+    @property
+    def right(self):
+        return self.sql_function.clauses.clauses[self.right_index - 1]
+
+    @right.setter
+    def right(self, value):
+        self.sql_function.clauses.clauses[self.right_index - 1] = value
+
+    def _copy_internals(self, **kw):
+        clone = kw.pop('clone')
+        self.sql_function = clone(self.sql_function, **kw)
+        super(FunctionAsBinary, self)._copy_internals(**kw)
+
+
 class _FunctionGenerator(object):
     """Generate :class:`.Function` objects based on getattr calls."""
 
index 11d19545507b54cdc27bfe1ca12a20d9148b1098..bda9a0c8658357d5a298f8dc08a336555f8b20bd 100644 (file)
@@ -1079,6 +1079,10 @@ def from_():
     raise NotImplementedError()
 
 
+def function_as_comparison_op():
+    raise NotImplementedError()
+
+
 def as_():
     raise NotImplementedError()
 
@@ -1260,7 +1264,8 @@ def json_path_getitem_op(a, b):
 _commutative = {eq, ne, add, mul}
 
 _comparison = {eq, ne, lt, gt, ge, le, between_op, like_op, is_,
-               isnot, is_distinct_from, isnot_distinct_from}
+               isnot, is_distinct_from, isnot_distinct_from,
+               function_as_comparison_op}
 
 
 def is_comparison(op):
@@ -1314,6 +1319,7 @@ _largest = util.symbol('_largest', canonical=100)
 
 _PRECEDENCE = {
     from_: 15,
+    function_as_comparison_op: 15,
     any_op: 15,
     all_op: 15,
     getitem: 15,
index 77acae264f3a4516723b574c3d1108c3f60b391e..ce6d77d91fdc07cb2b46c414365e10273f1af299 100644 (file)
@@ -2838,6 +2838,66 @@ class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest):
         self._assert_raises_no_local_remote(configure_mappers, "T1.t3s")
 
 
+class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest):
+    """test :ticket:`3831`
+
+    """
+
+    __only_on__= 'sqlite'
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Venue(Base):
+            __tablename__ = 'venue'
+            id = Column(Integer, primary_key=True)
+            name = Column(String)
+
+            descendants = relationship(
+                "Venue",
+                primaryjoin=func.instr(
+                    remote(foreign(name)), name + "/").as_comparison(1, 2) == 1,
+                viewonly=True,
+                order_by=name
+            )
+
+    @classmethod
+    def insert_data(cls):
+        Venue = cls.classes.Venue
+        s = Session()
+        s.add_all([
+            Venue(name="parent1"),
+            Venue(name="parent2"),
+            Venue(name="parent1/child1"),
+            Venue(name="parent1/child2"),
+            Venue(name="parent2/child1"),
+        ])
+        s.commit()
+
+    def test_lazyload(self):
+        Venue = self.classes.Venue
+        s = Session()
+        v1 = s.query(Venue).filter_by(name="parent1").one()
+        eq_(
+            [d.name for d in v1.descendants],
+            ['parent1/child1', 'parent1/child2'])
+
+    def test_joinedload(self):
+        Venue = self.classes.Venue
+        s = Session()
+
+        def go():
+            v1 = s.query(Venue).filter_by(name="parent1").\
+                options(joinedload(Venue.descendants)).one()
+
+            eq_(
+                [d.name for d in v1.descendants],
+                ['parent1/child1', 'parent1/child2'])
+
+        self.assert_sql_count(testing.db, go, 1)
+
+
 class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest):
 
     """test a complex annotation using between().
index cbc02f4b869f39a43c0f38ea5fec5d7da8d7ef21..3032c3ce3afd92f19315bb44bedda5b306545d7a 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy.testing import eq_, is_
 import datetime
 from sqlalchemy import func, select, Integer, literal, DateTime, Table, \
     Column, Sequence, MetaData, extract, Date, String, bindparam, \
-    literal_column, ARRAY, Numeric
+    literal_column, ARRAY, Numeric, Boolean
 from sqlalchemy.sql import table, column
 from sqlalchemy import sql, util
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
@@ -589,6 +589,70 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             MissingType().compile
         )
 
+    def test_as_comparison(self):
+
+        fn = func.substring("foo", "foobar").as_comparison(1, 2)
+        is_(fn.type._type_affinity, Boolean)
+
+        self.assert_compile(
+            fn.left, ":substring_1",
+            checkparams={'substring_1': 'foo'})
+        self.assert_compile(
+            fn.right, ":substring_1",
+            checkparams={'substring_1': 'foobar'})
+
+        self.assert_compile(
+            fn, "substring(:substring_1, :substring_2)",
+            checkparams={"substring_1": "foo", "substring_2": "foobar"})
+
+    def test_as_comparison_annotate(self):
+
+        fn = func.foobar("x", "y", "q", "p", "r").as_comparison(2, 5)
+
+        from sqlalchemy.sql import annotation
+        fn_annotated = annotation._deep_annotate(fn, {"token": "yes"})
+
+        eq_(fn.left._annotations, {})
+        eq_(fn_annotated.left._annotations, {"token": "yes"})
+
+    def test_as_comparison_many_argument(self):
+
+        fn = func.some_comparison("x", "y", "z", "p", "q", "r").as_comparison(2, 5)
+        is_(fn.type._type_affinity, Boolean)
+
+        self.assert_compile(
+            fn.left, ":some_comparison_1",
+            checkparams={"some_comparison_1": "y"})
+        self.assert_compile(
+            fn.right, ":some_comparison_1",
+            checkparams={"some_comparison_1": "q"})
+
+        from sqlalchemy.sql import visitors
+
+        fn_2 = visitors.cloned_traverse(fn, {}, {})
+        fn_2.right = literal_column("ABC")
+
+        self.assert_compile(
+            fn,
+            "some_comparison(:some_comparison_1, :some_comparison_2, "
+            ":some_comparison_3, "
+            ":some_comparison_4, :some_comparison_5, :some_comparison_6)",
+            checkparams={
+                'some_comparison_1': 'x', 'some_comparison_2': 'y',
+                'some_comparison_3': 'z', 'some_comparison_4': 'p',
+                'some_comparison_5': 'q', 'some_comparison_6': 'r'})
+
+        self.assert_compile(
+            fn_2,
+            "some_comparison(:some_comparison_1, :some_comparison_2, "
+            ":some_comparison_3, "
+            ":some_comparison_4, ABC, :some_comparison_5)",
+            checkparams={
+                'some_comparison_1': 'x', 'some_comparison_2': 'y',
+                'some_comparison_3': 'z', 'some_comparison_4': 'p',
+                'some_comparison_5': 'r'}
+        )
+
 
 class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):