]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] Added support for MSSQL INSERT,
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Mar 2012 21:00:05 +0000 (14:00 -0700)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Mar 2012 21:00:05 +0000 (14:00 -0700)
UPDATE, and DELETE table hints, using
new with_hint() method on UpdateBase.
[ticket:2430]

CHANGES
doc/build/core/expression_api.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/test_mssql.py

diff --git a/CHANGES b/CHANGES
index 6c94f3161c15efc840710a1342bc3d3fd708f5d2..91262ea2660edd83621fc99ec0e292d57fdc11b4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -181,6 +181,12 @@ CHANGES
     commit or rollback transaction with errors
     on engine.begin().
 
+- mssql
+  - [feature] Added support for MSSQL INSERT, 
+    UPDATE, and DELETE table hints, using
+    new with_hint() method on UpdateBase.
+    [ticket:2430]
+
 - mysql
   - [feature] Added support for the "isolation_level"
     parameter to all MySQL dialects.  Thanks
index 4cec26f98265731a6c3e03aaa9e8f8af6894ccd8..fefc8eb592781152e9dc57cb82338f3b2f783369 100644 (file)
@@ -220,11 +220,11 @@ Classes
    :show-inheritance:
 
 .. autoclass:: Update
-  :members: where, values
+  :members:
   :show-inheritance:
 
 .. autoclass:: UpdateBase
-  :members: params, bind, returning
+  :members:
   :show-inheritance:
 
 .. autoclass:: ValuesBase
index b73235875c0d503cc0f9590c4c06437c43074301..103b0a3e992f5c3bd4b1c1bf1c8317b1449422a3 100644 (file)
@@ -791,6 +791,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def get_from_hint_text(self, table, text):
         return text
 
+    def get_crud_hint_text(self, table, text):
+        return text
+
     def limit_clause(self, select):
         # Limit in mssql is after the select keyword
         return ""
index d0dd28e70ca945ac0b1af6238b6c92aa88a2c82a..d71acbc5936b09cf3742b2aed26b746e75b8c3c7 100644 (file)
@@ -1348,7 +1348,8 @@ class MySQLCompiler(compiler.SQLCompiler):
         return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) 
                     for t in [from_table] + list(extra_froms))
 
-    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+    def update_from_clause(self, update_stmt, from_table, 
+                                extra_froms, from_hints, **kw):
         return None
 
 
index 6f010ed54c7b53ec2cd1744fc722547e983c9e7d..c5c6f9ec8b89d95432c053b839e3358c8a1d87df 100644 (file)
@@ -855,6 +855,9 @@ class SQLCompiler(engine.Compiled):
     def get_from_hint_text(self, table, text):
         return None
 
+    def get_crud_hint_text(self, table, text):
+        return None
+
     def visit_select(self, select, asfrom=False, parens=True, 
                             iswrapper=False, fromhints=None, 
                             compound_index=1, **kwargs):
@@ -1048,12 +1051,26 @@ class SQLCompiler(engine.Compiled):
 
         text = "INSERT"
 
+
         prefixes = [self.process(x) for x in insert_stmt._prefixes]
         if prefixes:
             text += " " + " ".join(prefixes)
 
         text += " INTO " + preparer.format_table(insert_stmt.table)
 
+        if insert_stmt._hints:
+            dialect_hints = dict([
+                (table, hint_text)
+                for (table, dialect), hint_text in 
+                insert_stmt._hints.items()
+                if dialect in ('*', self.dialect.name)
+            ])
+            if insert_stmt.table in dialect_hints:
+                text += " " + self.get_crud_hint_text(
+                                    insert_stmt.table, 
+                                    dialect_hints[insert_stmt.table]
+                                )
+
         if colparams or not supports_default_values:
             text += " (%s)" % ', '.join([preparer.format_column(c[0])
                        for c in colparams])
@@ -1085,21 +1102,25 @@ class SQLCompiler(engine.Compiled):
                                             extra_froms, **kw):
         """Provide a hook to override the initial table clause
         in an UPDATE statement.
-        
+
         MySQL overrides this.
 
         """
         return self.preparer.format_table(from_table)
 
-    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+    def update_from_clause(self, update_stmt, 
+                                from_table, extra_froms, 
+                                from_hints,
+                                **kw):
         """Provide a hook to override the generation of an 
         UPDATE..FROM clause.
-        
+
         MySQL overrides this.
 
         """
         return "FROM " + ', '.join(
-                    t._compiler_dispatch(self, asfrom=True, **kw) 
+                    t._compiler_dispatch(self, asfrom=True, 
+                                    fromhints=from_hints, **kw) 
                     for t in extra_froms)
 
     def visit_update(self, update_stmt, **kw):
@@ -1116,6 +1137,21 @@ class SQLCompiler(engine.Compiled):
                                         update_stmt.table, 
                                         extra_froms, **kw)
 
+        if update_stmt._hints:
+            dialect_hints = dict([
+                (table, hint_text)
+                for (table, dialect), hint_text in 
+                update_stmt._hints.items()
+                if dialect in ('*', self.dialect.name)
+            ])
+            if update_stmt.table in dialect_hints:
+                text += " " + self.get_crud_hint_text(
+                                    update_stmt.table, 
+                                    dialect_hints[update_stmt.table]
+                                )
+        else:
+            dialect_hints = None
+
         text += ' SET '
         if extra_froms and self.render_table_with_column_in_update_from:
             text += ', '.join(
@@ -1138,7 +1174,8 @@ class SQLCompiler(engine.Compiled):
             extra_from_text = self.update_from_clause(
                                         update_stmt, 
                                         update_stmt.table, 
-                                        extra_froms, **kw)
+                                        extra_froms, 
+                                        dialect_hints, **kw)
             if extra_from_text:
                 text += " " + extra_from_text
 
@@ -1377,6 +1414,21 @@ class SQLCompiler(engine.Compiled):
 
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
+        if delete_stmt._hints:
+            dialect_hints = dict([
+                (table, hint_text)
+                for (table, dialect), hint_text in 
+                delete_stmt._hints.items()
+                if dialect in ('*', self.dialect.name)
+            ])
+            if delete_stmt.table in dialect_hints:
+                text += " " + self.get_crud_hint_text(
+                                    delete_stmt.table, 
+                                    dialect_hints[delete_stmt.table]
+                                )
+        else:
+            dialect_hints = None
+
         if delete_stmt._returning:
             self.returning = delete_stmt._returning
             if self.returning_precedes_values:
index 50b7375bfc8535e013e51d9740f238e197195f4c..aa67f44fa0827f615e22f1a5f7ad64ff54ef7575 100644 (file)
@@ -4833,7 +4833,7 @@ class Select(_SelectBase):
         The text of the hint is rendered in the appropriate
         location for the database backend in use, relative
         to the given :class:`.Table` or :class:`.Alias` passed as the
-        *selectable* argument. The dialect implementation
+        ``selectable`` argument. The dialect implementation
         typically uses Python string substitution syntax
         with the token ``%(name)s`` to render the name of
         the table or alias. E.g. when using Oracle, the
@@ -5319,6 +5319,7 @@ class UpdateBase(Executable, ClauseElement):
     _execution_options = \
         Executable._execution_options.union({'autocommit': True})
     kwargs = util.immutabledict()
+    _hints = util.immutabledict()
 
     def _process_colparams(self, parameters):
         if isinstance(parameters, (list, tuple)):
@@ -5399,6 +5400,45 @@ class UpdateBase(Executable, ClauseElement):
         """
         self._returning = cols
 
+    @_generative
+    def with_hint(self, text, selectable=None, dialect_name="*"):
+        """Add a table hint for a single table to this 
+        INSERT/UPDATE/DELETE statement.
+
+        .. note::
+
+         :meth:`.UpdateBase.with_hint` currently applies only to 
+         Microsoft SQL Server.  For MySQL INSERT hints, use
+         :meth:`.Insert.prefix_with`.   UPDATE/DELETE hints for 
+         MySQL will be added in a future release.
+         
+        The text of the hint is rendered in the appropriate
+        location for the database backend in use, relative
+        to the :class:`.Table` that is the subject of this
+        statement, or optionally to that of the given 
+        :class:`.Table` passed as the ``selectable`` argument.
+
+        The ``dialect_name`` option will limit the rendering of a particular
+        hint to a particular backend. Such as, to add a hint
+        that only takes effect for SQL Server::
+
+            mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
+
+        New in 0.7.6.
+
+        :param text: Text of the hint.
+        :param selectable: optional :class:`.Table` that specifies
+         an element of the FROM clause within an UPDATE or DELETE
+         to be the subject of the hint - applies only to certain backends.
+        :param dialect_name: defaults to ``*``, if specified as the name
+         of a particular dialect, will apply these hints only when
+         that dialect is in use.
+         """
+        if selectable is None:
+            selectable = self.table
+
+        self._hints = self._hints.union({(selectable, dialect_name):text})
+
 class ValuesBase(UpdateBase):
     """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs."""
 
index 94609d9534f81cadcbb737ce4e1712332c5b5ded..dddc6333d3c7d4ba75d1466381890c8716fec2e6 100644 (file)
@@ -63,6 +63,96 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             'n WHERE sometable.somecolumn = '
                             ':somecolumn_1', dict(somecolumn=10))
 
+    def test_insert_hint(self):
+        t = table('sometable', column('somecolumn'))
+        for targ in (None, t):
+            for darg in ("*", "mssql"):
+                self.assert_compile(
+                    t.insert().
+                        values(somecolumn="x").
+                        with_hint("WITH (PAGLOCK)",
+                            selectable=targ,
+                            dialect_name=darg),
+                    "INSERT INTO sometable WITH (PAGLOCK) "
+                    "(somecolumn) VALUES (:somecolumn)"
+                )
+
+    def test_update_hint(self):
+        t = table('sometable', column('somecolumn'))
+        for targ in (None, t):
+            for darg in ("*", "mssql"):
+                self.assert_compile(
+                    t.update().where(t.c.somecolumn=="q").
+                            values(somecolumn="x").
+                            with_hint("WITH (PAGLOCK)", 
+                                    selectable=targ, 
+                                    dialect_name=darg),
+                    "UPDATE sometable WITH (PAGLOCK) "
+                    "SET somecolumn=:somecolumn "
+                    "WHERE sometable.somecolumn = :somecolumn_1"
+                )
+
+    def test_update_exclude_hint(self):
+        t = table('sometable', column('somecolumn'))
+        self.assert_compile(
+            t.update().where(t.c.somecolumn=="q").
+                values(somecolumn="x").
+                with_hint("XYZ", "mysql"),
+            "UPDATE sometable SET somecolumn=:somecolumn "
+            "WHERE sometable.somecolumn = :somecolumn_1"
+        )
+
+    def test_delete_hint(self):
+        t = table('sometable', column('somecolumn'))
+        for targ in (None, t):
+            for darg in ("*", "mssql"):
+                self.assert_compile(
+                    t.delete().where(t.c.somecolumn=="q").
+                            with_hint("WITH (PAGLOCK)", 
+                                    selectable=targ, 
+                                    dialect_name=darg),
+                    "DELETE FROM sometable WITH (PAGLOCK) "
+                    "WHERE sometable.somecolumn = :somecolumn_1"
+                )
+
+    def test_delete_exclude_hint(self):
+        t = table('sometable', column('somecolumn'))
+        self.assert_compile(
+            t.delete().\
+                where(t.c.somecolumn=="q").\
+                with_hint("XYZ", dialect_name="mysql"),
+            "DELETE FROM sometable WHERE "
+            "sometable.somecolumn = :somecolumn_1"
+        )
+
+    def test_update_from_hint(self):
+        t = table('sometable', column('somecolumn'))
+        t2 = table('othertable', column('somecolumn'))
+        for darg in ("*", "mssql"):
+            self.assert_compile(
+                t.update().where(t.c.somecolumn==t2.c.somecolumn).
+                        values(somecolumn="x").
+                        with_hint("WITH (PAGLOCK)", 
+                                selectable=t2, 
+                                dialect_name=darg),
+                "UPDATE sometable SET somecolumn=:somecolumn "
+                "FROM othertable WITH (PAGLOCK) "
+                "WHERE sometable.somecolumn = othertable.somecolumn"
+            )
+
+    # TODO: not supported yet.
+    #def test_delete_from_hint(self):
+    #    t = table('sometable', column('somecolumn'))
+    #    t2 = table('othertable', column('somecolumn'))
+    #    for darg in ("*", "mssql"):
+    #        self.assert_compile(
+    #            t.delete().where(t.c.somecolumn==t2.c.somecolumn).
+    #                    with_hint("WITH (PAGLOCK)", 
+    #                            selectable=t2, 
+    #                            dialect_name=darg),
+    #            ""
+    #        )
+
     # TODO: should this be for *all* MS-SQL dialects ?
     def test_mxodbc_binds(self):
         """mxodbc uses MS-SQL native binds, which aren't allowed in