git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new "expanding" feature to bindparam()
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2017 18:34:58 +0000 (14:34 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Apr 2017 19:53:49 +0000 (15:53 -0400)
Added a new kind of :func:`.bindparam` called "expanding".  This is
for use in ``IN`` expressions where the list of elements is rendered
into individual bound parameters at statement execution time, rather
than at statement compilation time.  This allows both a single bound
parameter name to be linked to an IN expression of multiple elements,
as well as allows query caching to be used with IN expressions.  The
new feature allows the related features of "select in" loading and
"polymorphic in" loading to make use of the baked query extension
to reduce call overhead.   This feature should be considered to be
**experimental** for 1.2.

Fixes: #3953
Change-Id: Ie708414a3ab9c0af29998a2c7f239ff7633b1f6e

12 files changed:
doc/build/changelog/changelog_12.rst
doc/build/changelog/migration_12.rst
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_select.py
test/requirements.py
test/sql/test_compiler.py
test/sql/test_operators.py
test/sql/test_query.py

index 2070b92c03e8391fd2537efdb8a6224ba9bb76f9..689323137cf0db29b3b5062da73e79910c5ff67f 100644 (file)
 
             :ref:`change_2626`
 
+    .. change:: 3953
+        :tags: feature, sql
+        :tickets: 3953
+
+        Added a new kind of :func:`.bindparam` called "expanding".  This is
+        for use in ``IN`` expressions where the list of elements is rendered
+        into individual bound parameters at statement execution time, rather
+        than at statement compilation time.  This allows a single bound
+        parameter name to be linked to an IN expression of multiple elements,
+        and also allows query caching to be used with IN expressions.  The
+        new feature allows the related features of "select in" loading and
+        "polymorphic in" loading to make use of the baked query extension
+        to reduce call overhead.   This feature should be considered to be
+        **experimental** for 1.2.
+
+        .. seealso::
+
+            :ref:`change_3953`
+
     .. change:: 3923
         :tags: bug, sql
         :tickets: 3923
index 1dda856696a65a50eb7ab687725862665b647ad8..e36e0af0877230b5d500974099b6c025b3283aba 100644 (file)
@@ -309,6 +309,30 @@ warning.   However, it is anticipated that most users will appreciate the
 
 :ticket:`3907`
 
+.. _change_3953:
+
+Late-expanded IN parameter sets allow IN expressions with cached statements
+---------------------------------------------------------------------------
+
+Added a new kind of :func:`.bindparam` called "expanding".  This is
+for use in ``IN`` expressions where the list of elements is rendered
+into individual bound parameters at statement execution time, rather
+than at statement compilation time.  This allows a single bound
+parameter name to be linked to an IN expression of multiple elements,
+and also allows query caching to be used with IN expressions.  The
+new feature allows the related features of "select in" loading and
+"polymorphic in" loading to make use of the baked query extension
+to reduce call overhead::
+
+    stmt = select([table]).where(
+        table.c.col.in_(bindparam('foo', expanding=True)))
+    conn.execute(stmt, {"foo": [1, 2, 3]})
+
+The feature should be regarded as **experimental** within the 1.2 series.
+
+
+:ticket:`3953`
+
 .. _change_1546:
 
 Support for SQL Comments on Table, Column, includes DDL, reflection
index 628e23c9e6c868949b1ca717b6c31652a683cc20..d1b54ab01a0a568e921573fbd852c0b7a535fa4e 100644 (file)
@@ -552,6 +552,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
     # result column names
     _translate_colname = None
 
+    _expanded_parameters = util.immutabledict()
+
     @classmethod
     def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl):
         """Initialize execution context for a DDLElement construct."""
@@ -645,6 +647,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
         processors = compiled._bind_processors
 
+        if compiled.contains_expanding_parameters:
+            positiontup = self._expand_in_parameters(compiled, processors)
+        elif compiled.positional:
+            positiontup = self.compiled.positiontup
+
         # Convert the dictionary of bind parameter values
         # into a dict or list to be sent to the DBAPI's
         # execute() or executemany() method.
@@ -652,7 +659,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         if compiled.positional:
             for compiled_params in self.compiled_parameters:
                 param = []
-                for key in self.compiled.positiontup:
+                for key in positiontup:
                     if key in processors:
                         param.append(processors[key](compiled_params[key]))
                     else:
@@ -684,10 +691,97 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                     )
 
                 parameters.append(param)
+
         self.parameters = dialect.execute_sequence_format(parameters)
 
         return self
 
+    def _expand_in_parameters(self, compiled, processors):
+        """Expand 'expanding' IN parameters into one bound parameter per
+        element, rewriting self.statement and the first parameter set in
+        place; returns the new positiontup (None for named paramstyles).
+
+        Raises InvalidRequestError for executemany() or an empty value list,
+        and NotImplementedError for the 'numeric' paramstyle.
+
+        """
+        if self.executemany:
+            raise exc.InvalidRequestError(
+                "'expanding' parameters can't be used with "
+                "executemany()")
+
+        if self.compiled.positional and self.compiled._numeric_binds:
+            # I'm not familiar with any DBAPI that uses 'numeric'
+            raise NotImplementedError(
+                "'expanding' bind parameters not supported with "
+                "'numeric' paramstyle at this time.")
+
+        self._expanded_parameters = {}
+
+        compiled_params = self.compiled_parameters[0]
+        positiontup = [] if compiled.positional else None
+
+        replacement_expressions = {}
+        for name in (
+            self.compiled.positiontup if compiled.positional
+            else self.compiled.binds
+        ):
+            parameter = self.compiled.binds[name]
+            if parameter.expanding:
+                values = compiled_params.pop(name)
+                if not values:
+                    raise exc.InvalidRequestError(
+                        "'expanding' parameters can't be used with an "
+                        "empty list")
+                elif isinstance(values[0], (tuple, list)):
+                    to_update = [
+                        ("%s_%s_%s" % (name, i, j), value)
+                        for i, tuple_element in enumerate(values, 1)
+                        for j, value in enumerate(tuple_element, 1)
+                    ]
+                    replacement_expressions[name] = ", ".join(
+                        "(%s)" % ", ".join(
+                            self.compiled.bindtemplate % {
+                                "name":
+                                to_update[i * len(tuple_element) + j][0]
+                            }
+                            for j, value in enumerate(tuple_element)
+                        )
+                        for i, tuple_element in enumerate(values)
+                    )
+                else:
+                    to_update = [
+                        ("%s_%s" % (name, i), value)
+                        for i, value in enumerate(values, 1)
+                    ]
+                    replacement_expressions[name] = ", ".join(
+                        self.compiled.bindtemplate % {"name": key}
+                        for key, value in to_update
+                    )
+                compiled_params.update(to_update)
+                # key on the expanded name, not the (key, value) pair
+                processors.update(
+                    (key, processors[name])
+                    for key, value in to_update if name in processors
+                )
+                if compiled.positional:
+                    positiontup.extend(name for name, value in to_update)
+                self._expanded_parameters[name] = [
+                    expand_key for expand_key, value in to_update]
+            elif compiled.positional:
+                positiontup.append(name)
+
+        def process_expanding(m):
+            return replacement_expressions[m.group(1)]
+
+        # non-greedy so that each [EXPANDING_x] token is matched on its own
+        self.statement = re.sub(
+            r"\[EXPANDING_(.+?)\]",
+            process_expanding,
+            self.statement
+        )
+        return positiontup
+
     @classmethod
     def _init_statement(cls, dialect, connection, dbapi_connection,
                         statement, parameters):
@@ -1039,7 +1133,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                     get_dbapi_type(self.dialect.dbapi)
                 if dbtype is not None and \
                         (not exclude_types or dbtype not in exclude_types):
-                    inputsizes.append(dbtype)
+                    if key in self._expanded_parameters:
+                        inputsizes.extend(
+                            [dbtype] * len(self._expanded_parameters[key]))
+                    else:
+                        inputsizes.append(dbtype)
             try:
                 self.cursor.setinputsizes(*inputsizes)
             except BaseException as e:
@@ -1054,10 +1152,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 if dbtype is not None and \
                         (not exclude_types or dbtype not in exclude_types):
                     if translate:
+                        # TODO: this part won't work w/ the
+                        # expanded_parameters feature, e.g. for cx_oracle
+                        # quoted bound names
                         key = translate.get(key, key)
                     if not self.dialect.supports_unicode_binds:
                         key = self.dialect._encoder(key)[0]
-                    inputsizes[key] = dbtype
+                    if key in self._expanded_parameters:
+                        inputsizes.update(
+                            (expand_key, dbtype) for expand_key
+                            in self._expanded_parameters[key]
+                        )
+                    else:
+                        inputsizes[key] = dbtype
             try:
                 self.cursor.setinputsizes(**inputsizes)
             except BaseException as e:
index cc42480096d76f6ae2a588e078cb573bc65ccea7..6da0647970e4acf47993509e38d17cb4e2e4b0d1 100644 (file)
@@ -350,6 +350,14 @@ class SQLCompiler(Compiled):
     columns with the table name (i.e. MySQL only)
     """
 
+    contains_expanding_parameters = False
+    """True if we've encountered bindparam(..., expanding=True).
+
+    These need to be converted before execution time against the
+    string statement.
+
+    """
+
     ansi_bind_rules = False
     """SQL 92 doesn't allow bind parameters to be used
     in the columns clause of a SELECT, nor does it allow
@@ -370,8 +378,14 @@ class SQLCompiler(Compiled):
     True unless using an unordered TextAsFrom.
     """
 
-    insert_prefetch = update_prefetch = ()
+    _numeric_binds = False
+    """
+    True if paramstyle is "numeric".  This paramstyle is trickier than
+    all the others.
 
+    """
+
+    insert_prefetch = update_prefetch = ()
 
     def __init__(self, dialect, statement, column_keys=None,
                  inline=False, **kwargs):
@@ -418,6 +432,7 @@ class SQLCompiler(Compiled):
         self.positional = dialect.positional
         if self.positional:
             self.positiontup = []
+            self._numeric_binds = dialect.paramstyle == "numeric"
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
         self.ctes = None
@@ -439,7 +454,7 @@ class SQLCompiler(Compiled):
         ) and statement._returning:
             self.returning = statement._returning
 
-        if self.positional and dialect.paramstyle == 'numeric':
+        if self.positional and self._numeric_binds:
             self._apply_numbered_params()
 
     @property
@@ -492,7 +507,8 @@ class SQLCompiler(Compiled):
         return dict(
             (key, value) for key, value in
             ((self.bind_names[bindparam],
-              bindparam.type._cached_bind_processor(self.dialect))
+              bindparam.type._cached_bind_processor(self.dialect)
+              )
              for bindparam in self.bind_names)
             if value is not None
         )
@@ -1238,7 +1254,8 @@ class SQLCompiler(Compiled):
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
 
-        return self.bindparam_string(name, **kwargs)
+        return self.bindparam_string(
+            name, expanding=bindparam.expanding, **kwargs)
 
     def render_literal_bindparam(self, bindparam, **kw):
         value = bindparam.effective_value
@@ -1300,13 +1317,18 @@ class SQLCompiler(Compiled):
         self.anon_map[derived] = anonymous_counter + 1
         return derived + "_" + str(anonymous_counter)
 
-    def bindparam_string(self, name, positional_names=None, **kw):
+    def bindparam_string(
+            self, name, positional_names=None, expanding=False, **kw):
         if self.positional:
             if positional_names is not None:
                 positional_names.append(name)
             else:
                 self.positiontup.append(name)
-        return self.bindtemplate % {'name': name}
+        if expanding:
+            self.contains_expanding_parameters = True
+            return "([EXPANDING_%s])" % name
+        else:
+            return self.bindtemplate % {'name': name}
 
     def visit_cte(self, cte, asfrom=False, ashint=False,
                   fromhints=None,
index d409ebacce49eb38bebdfd721ab99792496a221a..4ba53ef758086cbd3029af833145f8e4ca867f1c 100644 (file)
@@ -127,10 +127,18 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
         return _boolean_compare(expr, op, seq_or_selectable,
                                 negate=negate_op, **kw)
     elif isinstance(seq_or_selectable, ClauseElement):
-        raise exc.InvalidRequestError(
-            'in_() accepts'
-            ' either a list of expressions '
-            'or a selectable: %r' % seq_or_selectable)
+        if isinstance(seq_or_selectable, BindParameter) and \
+                seq_or_selectable.expanding:
+            return _boolean_compare(
+                expr, op,
+                seq_or_selectable,
+                negate=negate_op)
+        else:
+            raise exc.InvalidRequestError(
+                'in_() accepts'
+                ' either a list of expressions, '
+                'a selectable, or an "expanding" bound parameter: %r'
+                % seq_or_selectable)
 
     # Handle non selectable arguments as sequences
     args = []
@@ -139,8 +147,8 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
             if not isinstance(o, operators.ColumnOperators):
                 raise exc.InvalidRequestError(
                     'in_() accepts'
-                    ' either a list of expressions '
-                    'or a selectable: %r' % o)
+                    ' either a list of expressions, '
+                    'a selectable, or an "expanding" bound parameter: %r' % o)
         elif o is None:
             o = Null()
         else:
index 001c3d042dd080a7be33379265b6babb6aaaa3d2..414e3f47780b11e6d7911a985bf11db4e28122b6 100644 (file)
@@ -867,6 +867,7 @@ class BindParameter(ColumnElement):
     def __init__(self, key, value=NO_ARG, type_=None,
                  unique=False, required=NO_ARG,
                  quote=None, callable_=None,
+                 expanding=False,
                  isoutparam=False,
                  _compared_to_operator=None,
                  _compared_to_type=None):
@@ -1052,6 +1053,23 @@ class BindParameter(ColumnElement):
           "OUT" parameter.  This applies to backends such as Oracle which
           support OUT parameters.
 
+        :param expanding:
+          if True, this parameter will be treated as an "expanding" parameter
+          at execution time; the parameter value is expected to be a sequence,
+          rather than a scalar value, and the string SQL statement will
+          be transformed on a per-execution basis to accommodate the sequence
+          with a variable number of parameter slots passed to the DBAPI.
+          This is to allow statement caching to be used in conjunction with
+          an IN clause.
+
+          .. note:: The "expanding" feature does not support "executemany"-
+             style parameter sets, nor does it support empty IN expressions.
+
+          .. note:: The "expanding" feature should be considered as
+             **experimental** within the 1.2 series.
+
+          .. versionadded:: 1.2
+
         .. seealso::
 
             :ref:`coretutorial_bind_param`
@@ -1093,6 +1111,8 @@ class BindParameter(ColumnElement):
         self.callable = callable_
         self.isoutparam = isoutparam
         self.required = required
+        self.expanding = expanding
+
         if type_ is None:
             if _compared_to_type is not None:
                 self.type = \
index d38a6915956828c03c768b84cf84f1d0aee9a668..95aef0e1755c0b8078b4730b8ca9904a6de89100 100644 (file)
@@ -219,6 +219,14 @@ class SuiteRequirements(Requirements):
             "%(database)s %(does_support)s 'returning'"
         )
 
+    @property
+    def tuple_in(self):
+        """Target platform supports the syntax
+        "(x, y) IN ((x1, y1), (x2, y2), ...)"
+        """
+
+        return exclusions.closed()
+
     @property
     def duplicate_names_in_cursor_description(self):
         """target platform supports a SELECT statement that has
index e7de356b8590ff17e2649f14e30feb83e71ee7da..4086a4c24d8e0b68f2db342e988cec7e5007620b 100644 (file)
@@ -2,7 +2,7 @@ from .. import fixtures, config
 from ..assertions import eq_
 
 from sqlalchemy import util
-from sqlalchemy import Integer, String, select, func, bindparam, union
+from sqlalchemy import Integer, String, select, func, bindparam, union, tuple_
 from sqlalchemy import testing
 
 from ..schema import Table, Column
@@ -310,3 +310,57 @@ class CompoundSelectTest(fixtures.TablesTest):
             u1.order_by(u1.c.id),
             [(2, 2, 3), (3, 3, 4)]
         )
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("some_table", metadata,
+              Column('id', Integer, primary_key=True),
+              Column('x', Integer),
+              Column('y', Integer))
+
+    @classmethod
+    def insert_data(cls):
+        config.db.execute(
+            cls.tables.some_table.insert(),
+            [
+                {"id": 1, "x": 1, "y": 2},
+                {"id": 2, "x": 2, "y": 3},
+                {"id": 3, "x": 3, "y": 4},
+                {"id": 4, "x": 4, "y": 5},
+            ]
+        )
+
+    def _assert_result(self, select, result, params=()):
+        eq_(
+            config.db.execute(select, params).fetchall(),
+            result
+        )
+
+    def test_bound_in_scalar(self):
+        table = self.tables.some_table
+
+        stmt = select([table.c.id]).where(
+            table.c.x.in_(bindparam('q', expanding=True)))
+
+        self._assert_result(
+            stmt,
+            [(2, ), (3, ), (4, )],
+            params={"q": [2, 3, 4]},
+        )
+
+    @testing.requires.tuple_in
+    def test_bound_in_two_tuple(self):
+        table = self.tables.some_table
+
+        stmt = select([table.c.id]).where(
+            tuple_(table.c.x, table.c.y).in_(bindparam('q', expanding=True)))
+
+        self._assert_result(
+            stmt,
+            [(2, ), (3, ), (4, )],
+            params={"q": [(2, 3), (3, 4), (4, 5)]},
+        )
index ea940d16845db7f9d2a78e20917ad29913771cbe..63745a1137f782aa808cdbe232ed336286542252 100644 (file)
@@ -209,6 +209,10 @@ class DefaultRequirements(SuiteRequirements):
 
         return skip_if(["oracle", "mssql"], "not supported by database/driver")
 
+    @property
+    def tuple_in(self):
+        return only_on(["mysql", "postgresql"])
+
     @property
     def independent_cursors(self):
         """Target must support simultaneous, independent database cursors
index 8b19b8931424995a4b44ff48853d437e85b349ff..05893d748ca53ec00bc846324d5fb601b1116ece 100644 (file)
@@ -2174,6 +2174,18 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "myothertable.otherid, myothertable.othername FROM myothertable)"
         )
 
+    def test_expanding_parameter(self):
+        self.assert_compile(
+            tuple_(table1.c.myid, table1.c.name).in_(
+                bindparam('foo', expanding=True)),
+            "(mytable.myid, mytable.name) IN ([EXPANDING_foo])"
+        )
+
+        self.assert_compile(
+            table1.c.myid.in_(bindparam('foo', expanding=True)),
+            "mytable.myid IN ([EXPANDING_foo])"
+        )
+
     def test_cast(self):
         tbl = table('casttest',
                     column('id', Integer),
index 217af4337701efd630033f3888b6d26f68c558c5..ac05d3a8114b7bb6c7c198905acfe90b52adc50c 100644 (file)
@@ -184,7 +184,7 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         foo = ClauseList()
         assert_raises_message(
             exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions or a selectable:",
+            r"in_\(\) accepts either a list of expressions, a selectable",
             left.in_, [foo]
         )
 
@@ -193,7 +193,7 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         right = column('right')
         assert_raises_message(
             exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions or a selectable:",
+            r"in_\(\) accepts either a list of expressions, a selectable",
             left.in_, right
         )
 
@@ -210,7 +210,7 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         right = column('right', HasGetitem)
         assert_raises_message(
             exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions or a selectable:",
+            r"in_\(\) accepts either a list of expressions, a selectable",
             left.in_, right
         )
 
index d90cb04764d7300f6b33c1b43d9fd71bbb1b6a1a..28300855f863bd1577636feb59a554555e6e1c9f 100644 (file)
@@ -6,7 +6,7 @@ from sqlalchemy import (
     exc, sql, func, select, String, Integer, MetaData, and_, ForeignKey,
     union, intersect, except_, union_all, VARCHAR, INT, text,
     bindparam, literal, not_, literal_column, desc, asc,
-    TypeDecorator, or_, cast)
+    TypeDecorator, or_, cast, tuple_)
 from sqlalchemy.engine import default
 from sqlalchemy.testing.schema import Table, Column
 
@@ -405,7 +405,6 @@ class QueryTest(fixtures.TestBase):
                     use_labels=labels),
                 [(3, 'a'), (2, 'b'), (1, None)])
 
-    @testing.emits_warning('.*empty sequence.*')
     def test_in_filtering(self):
         """test the behavior of the in_() function."""
 
@@ -431,6 +430,77 @@ class QueryTest(fixtures.TestBase):
         # Null values are not outside any set
         assert len(r) == 0
 
+    def test_expanding_in(self):
+        testing.db.execute(
+            users.insert(),
+            [
+                dict(user_id=7, user_name='jack'),
+                dict(user_id=8, user_name='fred'),
+                dict(user_id=9, user_name=None)
+            ]
+        )
+
+        with testing.db.connect() as conn:
+            stmt = select([users]).where(
+                users.c.user_name.in_(bindparam('uname', expanding=True))
+            ).order_by(users.c.user_id)
+
+            eq_(
+                conn.execute(stmt, {"uname": ['jack']}).fetchall(),
+                [(7, 'jack')]
+            )
+
+            eq_(
+                conn.execute(stmt, {"uname": ['jack', 'fred']}).fetchall(),
+                [(7, 'jack'), (8, 'fred')]
+            )
+
+            assert_raises_message(
+                exc.StatementError,
+                "'expanding' parameters can't be used with an empty list",
+                conn.execute,
+                stmt, {"uname": []}
+            )
+
+            assert_raises_message(
+                exc.StatementError,
+                "'expanding' parameters can't be used with executemany()",
+                conn.execute,
+                users.update().where(
+                    users.c.user_name.in_(bindparam('uname', expanding=True))
+                ), [{"uname": ['fred']}, {"uname": ['ed']}]
+            )
+
+    @testing.requires.tuple_in
+    def test_expanding_in_composite(self):
+        testing.db.execute(
+            users.insert(),
+            [
+                dict(user_id=7, user_name='jack'),
+                dict(user_id=8, user_name='fred'),
+                dict(user_id=9, user_name=None)
+            ]
+        )
+
+        with testing.db.connect() as conn:
+            stmt = select([users]).where(
+                tuple_(
+                    users.c.user_id,
+                    users.c.user_name
+                ).in_(bindparam('uname', expanding=True))
+            ).order_by(users.c.user_id)
+
+            eq_(
+                conn.execute(stmt, {"uname": [(7, 'jack')]}).fetchall(),
+                [(7, 'jack')]
+            )
+
+            eq_(
+                conn.execute(stmt, {"uname": [(7, 'jack'), (8, 'fred')]}).fetchall(),
+                [(7, 'jack'), (8, 'fred')]
+            )
+
+
     @testing.fails_on('firebird', "uses sql-92 rules")
     @testing.fails_on('sybase', "uses sql-92 rules")
     @testing.fails_if(