From: Mike Bayer Date: Sat, 27 Mar 2010 21:18:53 +0000 (-0400) Subject: - Added with_hint() method to Query() construct. This calls X-Git-Tag: rel_0_6beta3~8 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=36047e9bb28501477b1403059087cccc120be2b6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Added with_hint() method to Query() construct. This calls directly down to select().with_hint() and also accepts entities as well as tables and aliases. See with_hint() in the SQL section below. [ticket:921] - Added with_hint() method to select() construct. Specify a table/alias, hint text, and optional dialect name, and "hints" will be rendered in the appropriate place in the statement. Works for Oracle, Sybase, MySQL. [ticket:921] --- diff --git a/CHANGES b/CHANGES index 7ab07aaff9..8022d0f3a3 100644 --- a/CHANGES +++ b/CHANGES @@ -33,6 +33,11 @@ CHANGES "select". The old values of True/ False/None still retain their usual meanings and will remain as synonyms for the foreseeable future. + + - Added with_hint() method to Query() construct. This calls + directly down to select().with_hint() and also accepts + entities as well as tables and aliases. See with_hint() in the + SQL section below. [ticket:921] - Fixed bug in Query whereby calling q.join(prop).from_self(...). join(prop) would fail to render the second join outside the @@ -63,6 +68,11 @@ CHANGES in that regard. - sql + - Added with_hint() method to select() construct. Specify + a table/alias, hint text, and optional dialect name, and + "hints" will be rendered in the appropriate place in the + statement. Works for Oracle, Sybase, MySQL. [ticket:921] + - Fixed bug introduced in 0.6beta2 where column labels would render inside of column expressions already assigned a label. [ticket:1747] diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 873dfd16c7..f9bb482354 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1154,7 +1154,10 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_match_op(self, binary, **kw): return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right)) - + + def get_from_hint_text(self, table, text): + return text + def visit_typeclause(self, typeclause): type_ = typeclause.type.dialect_impl(self.dialect) if isinstance(type_, sqltypes.Integer): @@ -1204,11 +1207,11 @@ class MySQLCompiler(compiler.SQLCompiler): # support can be added, preferably after dialects are # refactored to be version-sensitive. return ''.join( - (self.process(join.left, asfrom=True), + (self.process(join.left, asfrom=True, **kwargs), (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), - self.process(join.right, asfrom=True), + self.process(join.right, asfrom=True, **kwargs), " ON ", - self.process(join.onclause))) + self.process(join.onclause, **kwargs))) def for_update_clause(self, select): if select.for_update == 'read': diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 332fa805d1..475730988f 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -342,6 +342,11 @@ class OracleCompiler(compiler.SQLCompiler): def visit_match_op(self, binary, **kw): return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def get_select_hint_text(self, byfroms): + return " ".join( + "/*+ %s */" % text for table, text in byfroms.items() + ) + def function_argspec(self, fn, **kw): if len(fn.clauses) > 0: return compiler.SQLCompiler.function_argspec(self, fn, **kw) @@ -360,7 +365,9 @@ class OracleCompiler(compiler.SQLCompiler): if self.dialect.use_ansi: return compiler.SQLCompiler.visit_join(self, join, **kwargs) else: - return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + kwargs['asfrom'] = True + return self.process(join.left, **kwargs) + \ + ", " + self.process(join.right, **kwargs) def _get_nonansi_join_whereclause(self, froms): clauses = [] @@ -392,14 +399,18 @@ class OracleCompiler(compiler.SQLCompiler): def visit_sequence(self, seq): return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - def visit_alias(self, alias, asfrom=False, **kwargs): + def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - if asfrom: + + if asfrom or ashint: alias_name = isinstance(alias.name, expression._generated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name - - return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name) + + if ashint: + return alias_name + elif asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + \ + " " + self.preparer.format_alias(alias, alias_name) else: return self.process(alias.original, **kwargs) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index aaec7a504d..79e32b9681 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -277,6 +277,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): s += "START AT %s " % (select._offset+1,) return s + def get_from_hint_text(self, table, text): + return text + def limit_clause(self, select): # Limit in sybase is after the select keyword return "" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5b9169c2e3..e98ad8937c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -84,6 +84,7 @@ class Query(object): _params = util.frozendict() _attributes = util.frozendict() _with_options = () + _with_hints = () def __init__(self, entities, session=None): self.session = session @@ -718,6 +719,21 @@ class Query(object): for opt in opts: opt.process_query(self) + @_generative() + def with_hint(self, selectable, text, dialect_name=None): + """Add an indexing hint for the given entity or selectable to + this :class:`Query`. + + Functionality is passed straight through to + :meth:`~sqlalchemy.sql.expression.Select.with_hint`, + with the addition that ``selectable`` can be a + :class:`Table`, :class:`Alias`, or ORM entity / mapped class + /etc. + """ + mapper, selectable, is_aliased_class = _entity_info(selectable) + + self._with_hints += ((selectable, text, dialect_name),) + @_generative() def execution_options(self, **kwargs): """ Set non-SQL options which take effect during execution. @@ -2053,7 +2069,10 @@ class Query(object): order_by=context.order_by, **self._select_args ) - + + for hint in self._with_hints: + inner = inner.with_hint(*hint) + if self._correlate: inner = inner.correlate(*self._correlate) @@ -2108,6 +2127,10 @@ class Query(object): order_by=context.order_by, **self._select_args ) + + for hint in self._with_hints: + statement = statement.with_hint(*hint) + if self._execution_options: statement = statement.execution_options(**self._execution_options) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 75b3f79f06..78c65771b7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -628,13 +628,22 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} - def visit_alias(self, alias, asfrom=False, **kwargs): - if asfrom: + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: alias_name = isinstance(alias.name, sql._generated_label) and \ self._truncated_identifier("alias", alias.name) or alias.name - - return self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ self.preparer.format_alias(alias, alias_name) + + if fromhints and alias in fromhints: + hinttext = self.get_from_hint_text(alias, fromhints[alias]) + if hinttext: + ret += " " + hinttext + + return ret else: return self.process(alias.original, **kwargs) @@ -661,8 +670,15 @@ class SQLCompiler(engine.Compiled): else: return column + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + def visit_select(self, select, asfrom=False, parens=True, - iswrapper=False, compound_index=1, **kwargs): + iswrapper=False, fromhints=None, + compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -697,6 +713,18 @@ class SQLCompiler(engine.Compiled): ] text = "SELECT " # we're off to a good start ! + + if select._hints: + byfrom = dict([ + (from_, hinttext % {'name':self.process(from_, ashint=True)}) + for (from_, dialect), hinttext in + select._hints.iteritems() + if dialect in ('*', self.dialect.name) + ]) + hint_text = self.get_select_hint_text(byfrom) + if hint_text: + text += hint_text + " " + if select._prefixes: text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " text += self.get_select_precolumns(select) @@ -704,7 +732,16 @@ class SQLCompiler(engine.Compiled): if froms: text += " \nFROM " - text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms) + + if select._hints: + text += ', '.join([self.process(f, + asfrom=True, fromhints=byfrom, + **kwargs) + for f in froms]) + else: + text += ', '.join([self.process(f, + asfrom=True, **kwargs) + for f in froms]) else: text += self.default_from() @@ -767,20 +804,26 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + str(select._offset) return text - def visit_table(self, table, asfrom=False, **kwargs): - if asfrom: + def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): + if asfrom or ashint: if getattr(table, "schema", None): - return self.preparer.quote_schema(table.schema, table.quote_schema) + \ + ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \ "." + self.preparer.quote(table.name, table.quote) else: - return self.preparer.quote(table.name, table.quote) + ret = self.preparer.quote(table.name, table.quote) + if fromhints and table in fromhints: + hinttext = self.get_from_hint_text(table, fromhints[table]) + if hinttext: + ret += " " + hinttext + return ret else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + \ + return (self.process(join.left, asfrom=True, **kwargs) + \ (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ - self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + self.process(join.right, asfrom=True, **kwargs) + " ON " + \ + self.process(join.onclause, **kwargs)) def visit_sequence(self, seq): return None diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1e02ba96a7..3aaa06fd6e 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3557,6 +3557,7 @@ class Select(_SelectBaseMixin, FromClause): __visit_name__ = 'select' _prefixes = () + _hints = util.frozendict() def __init__(self, columns, @@ -3659,7 +3660,34 @@ class Select(_SelectBaseMixin, FromClause): """Return the displayed list of FromClause elements.""" return self._get_display_froms() - + + @_generative + def with_hint(self, selectable, text, dialect_name=None): + """Add an indexing hint for the given selectable to this :class:`Select`. + + The text of the hint is written specific to a specific backend, and + typically uses Python string substitution syntax to render the name + of the table or alias, such as for Oracle:: + + select([mytable]).with_hint(mytable, "+ index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular hint + to a particular backend. Such as, to add hints for both Oracle and + Sybase simultaneously:: + + select([mytable]).\ + with_hint(mytable, "+ index(%(name)s ix_mytable)", 'oracle').\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + """ + if not dialect_name: + dialect_name = '*' + self._hints = self._hints.union({(selectable, dialect_name):text}) + @property def type(self): raise exc.InvalidRequestError("Select objects don't have a type. " diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 67ff4e4722..89b66fd868 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1199,6 +1199,39 @@ class YieldTest(QueryTest): except StopIteration: pass +class HintsTest(QueryTest, AssertsCompiledSQL): + def test_hints(self): + from sqlalchemy.dialects import mysql + dialect = mysql.dialect() + + sess = create_session() + + self.assert_compile( + sess.query(User).with_hint(User, 'USE INDEX (col1_index,col2_index)'), + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users USE INDEX (col1_index,col2_index)", + dialect=dialect + ) + + self.assert_compile( + sess.query(User).with_hint(User, 'WITH INDEX col1_index', 'sybase'), + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users", + dialect=dialect + ) + + ualias = aliased(User) + self.assert_compile( + sess.query(User, ualias).with_hint(ualias, 'USE INDEX (col1_index,col2_index)'). + join((ualias, ualias.id > User.id)), + "SELECT users.id AS users_id, users.name AS users_name, " + "users_1.id AS users_1_id, users_1.name AS users_1_name " + "FROM users INNER JOIN users AS users_1 USE INDEX (col1_index,col2_index) " + "ON users.id < users_1.id", + dialect=dialect + ) + + class TextTest(QueryTest): def test_fulltext(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).from_statement("select * from users order by id").all() diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index fb32c29bb4..a5b97be38d 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1760,15 +1760,128 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A else: self.assert_compile(s1, "SELECT %s FROM (SELECT %s FROM mytable)" % (expr,expr)) + def test_hints(self): + s = select([table1.c.myid]).with_hint(table1, "test hint %(name)s") + + s2 = select([table1.c.myid]).\ + with_hint(table1, "index(%(name)s idx)", 'oracle').\ + with_hint(table1, "WITH HINT INDEX idx", 'sybase') + + a1 = table1.alias() + s3 = select([a1.c.myid]).with_hint(a1, "index(%(name)s hint)") + + subs4 = select([ + table1, table2 + ]).select_from(table1.join(table2, table1.c.myid==table2.c.otherid)).\ + with_hint(table1, 'hint1') + + s4 = select([table3]).select_from( + table3.join( + subs4, + subs4.c.othername==table3.c.otherstuff + ) + ).\ + with_hint(table3, 'hint3') + + subs5 = select([ + table1, table2 + ]).select_from(table1.join(table2, table1.c.myid==table2.c.otherid)) + s5 = select([table3]).select_from( + table3.join( + subs5, + subs5.c.othername==table3.c.otherstuff + ) + ).\ + with_hint(table3, 'hint3').\ + with_hint(table1, 'hint1') + + t1 = table('QuotedName', column('col1')) + s6 = select([t1.c.col1]).where(t1.c.col1>10).with_hint(t1, '%(name)s idx1') + a2 = t1.alias('SomeName') + s7 = select([a2.c.col1]).where(a2.c.col1>10).with_hint(a2, '%(name)s idx1') + + mysql_d, oracle_d, sybase_d = \ + mysql.dialect(), \ + oracle.dialect(), \ + sybase.dialect() + + for stmt, dialect, expected in [ + (s, mysql_d, + "SELECT mytable.myid FROM mytable test hint mytable"), + (s, oracle_d, + "SELECT /*+ test hint mytable */ mytable.myid FROM mytable"), + (s, sybase_d, + "SELECT mytable.myid FROM mytable test hint mytable"), + (s2, mysql_d, + "SELECT mytable.myid FROM mytable"), + (s2, oracle_d, + "SELECT /*+ index(mytable idx) */ mytable.myid FROM mytable"), + (s2, sybase_d, + "SELECT mytable.myid FROM mytable WITH HINT INDEX idx"), + (s3, mysql_d, + "SELECT mytable_1.myid FROM mytable AS mytable_1 " + "index(mytable_1 hint)"), + (s3, oracle_d, + "SELECT /*+ index(mytable_1 hint) */ mytable_1.myid FROM " + "mytable mytable_1"), + (s3, sybase_d, + "SELECT mytable_1.myid FROM mytable AS mytable_1 " + "index(mytable_1 hint)"), + (s4, mysql_d, + "SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable " + "hint3 INNER JOIN (SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid, " + "myothertable.othername FROM mytable hint1 INNER " + "JOIN myothertable ON mytable.myid = myothertable.otherid) " + "ON othername = thirdtable.otherstuff"), + (s4, sybase_d, + "SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable " + "hint3 JOIN (SELECT mytable.myid, mytable.name, " + "mytable.description, myothertable.otherid, " + "myothertable.othername FROM mytable hint1 " + "JOIN myothertable ON mytable.myid = myothertable.otherid) " + "ON othername = thirdtable.otherstuff"), + (s4, oracle_d, + "SELECT /*+ hint3 */ thirdtable.userid, thirdtable.otherstuff " + "FROM thirdtable JOIN (SELECT /*+ hint1 */ mytable.myid," + " mytable.name, mytable.description, myothertable.otherid," + " myothertable.othername FROM mytable JOIN myothertable ON" + " mytable.myid = myothertable.otherid) ON othername =" + " thirdtable.otherstuff"), + (s5, oracle_d, + "SELECT /*+ hint3 */ /*+ hint1 */ thirdtable.userid, " + "thirdtable.otherstuff " + "FROM thirdtable JOIN (SELECT mytable.myid," + " mytable.name, mytable.description, myothertable.otherid," + " myothertable.othername FROM mytable JOIN myothertable ON" + " mytable.myid = myothertable.otherid) ON othername =" + " thirdtable.otherstuff"), + (s6, oracle_d, + """SELECT /*+ "QuotedName" idx1 */ "QuotedName".col1 """ + """FROM "QuotedName" WHERE "QuotedName".col1 > :col1_1"""), + (s7, oracle_d, + """SELECT /*+ SomeName idx1 */ "SomeName".col1 FROM """ + """"QuotedName" "SomeName" WHERE "SomeName".col1 > :col1_1"""), + ]: + self.assert_compile( + stmt, + expected, + dialect=dialect + ) + class CRUDTest(TestBase, AssertsCompiledSQL): def test_insert(self): # generic insert, will create bind params for all columns - self.assert_compile(insert(table1), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") + self.assert_compile(insert(table1), + "INSERT INTO mytable (myid, name, description) " + "VALUES (:myid, :name, :description)") # insert with user-supplied bind params for specific columns, # cols provided literally self.assert_compile( - insert(table1, {table1.c.myid : bindparam('userid'), table1.c.name : bindparam('username')}), + insert(table1, { + table1.c.myid : bindparam('userid'), + table1.c.name : bindparam('username')}), "INSERT INTO mytable (myid, name) VALUES (:userid, :username)") # insert with user-supplied bind params for specific columns, cols @@ -1786,33 +1899,79 @@ class CRUDTest(TestBase, AssertsCompiledSQL): ) self.assert_compile( - insert(table1, values={table1.c.myid : bindparam('userid')}).values({table1.c.name : bindparam('username')}), + insert(table1, values={ + table1.c.myid : bindparam('userid') + }).values({table1.c.name : bindparam('username')}), "INSERT INTO mytable (myid, name) VALUES (:userid, :username)" ) - self.assert_compile(insert(table1, values=dict(myid=func.lala())), "INSERT INTO mytable (myid) VALUES (lala())") + self.assert_compile( + insert(table1, values=dict(myid=func.lala())), + "INSERT INTO mytable (myid) VALUES (lala())") def test_inline_insert(self): metadata = MetaData() table = Table('sometable', metadata, Column('id', Integer, primary_key=True), Column('foo', Integer, default=func.foobar())) - self.assert_compile(table.insert(values={}, inline=True), "INSERT INTO sometable (foo) VALUES (foobar())") - self.assert_compile(table.insert(inline=True), "INSERT INTO sometable (foo) VALUES (foobar())", params={}) + self.assert_compile( + table.insert(values={}, inline=True), + "INSERT INTO sometable (foo) VALUES (foobar())") + self.assert_compile( + table.insert(inline=True), + "INSERT INTO sometable (foo) VALUES (foobar())", params={}) def test_update(self): - self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", params = {table1.c.name:'fred'}) - self.assert_compile(table1.update().where(table1.c.myid==7).values({table1.c.myid:5}), "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", checkparams={'myid':5, 'myid_1':7}) - self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", params = {'name':'fred'}) - self.assert_compile(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid") - self.assert_compile(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}, checkparams={'crit':'notthere', 'name':'hi'}) - self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :myid_1", params = {'description':'test'}, checkparams={'description':'test', 'myid_1':12}) - self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :myid_1", params = {'myid_1': 12, 'myid': 9, 'description': 'test'}) - self.assert_compile(update(table1, table1.c.myid ==12), "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", params={'myid':18}, checkparams={'myid':18, 'myid_1':12}) + self.assert_compile( + update(table1, table1.c.myid == 7), + "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", + params = {table1.c.name:'fred'}) + self.assert_compile( + table1.update().where(table1.c.myid==7). + values({table1.c.myid:5}), + "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", + checkparams={'myid':5, 'myid_1':7}) + self.assert_compile( + update(table1, table1.c.myid == 7), + "UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1", + params = {'name':'fred'}) + self.assert_compile( + update(table1, values = {table1.c.name : table1.c.myid}), + "UPDATE mytable SET name=mytable.myid") + self.assert_compile( + update(table1, + whereclause = table1.c.name == bindparam('crit'), + values = {table1.c.name : 'hi'}), + "UPDATE mytable SET name=:name WHERE mytable.name = :crit", + params = {'crit' : 'notthere'}, + checkparams={'crit':'notthere', 'name':'hi'}) + self.assert_compile( + update(table1, table1.c.myid == 12, + values = {table1.c.name : table1.c.myid}), + "UPDATE mytable SET name=mytable.myid, description=" + ":description WHERE mytable.myid = :myid_1", + params = {'description':'test'}, + checkparams={'description':'test', 'myid_1':12}) + self.assert_compile( + update(table1, table1.c.myid == 12, + values = {table1.c.myid : 9}), + "UPDATE mytable SET myid=:myid, description=:description " + "WHERE mytable.myid = :myid_1", + params = {'myid_1': 12, 'myid': 9, 'description': 'test'}) + self.assert_compile( + update(table1, table1.c.myid ==12), + "UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1", + params={'myid':18}, checkparams={'myid':18, 'myid_1':12}) s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'}) c = s.compile(column_keys=['id', 'name']) - self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}).values({table1.c.name:table1.c.name + 'foo'}), "UPDATE mytable SET name=(mytable.name || :name_1), description=:description WHERE mytable.myid = :myid_1", params = {'description':'test'}) - self.assert_(str(s) == str(c)) + self.assert_compile( + update(table1, table1.c.myid == 12, + values = {table1.c.name : table1.c.myid} + ).values({table1.c.name:table1.c.name + 'foo'}), + "UPDATE mytable SET name=(mytable.name || :name_1), " + "description=:description WHERE mytable.myid = :myid_1", + params = {'description':'test'}) + eq_(str(s), str(c)) self.assert_compile(update(table1, (table1.c.myid == func.hoho(4)) & @@ -1820,28 +1979,45 @@ class CRUDTest(TestBase, AssertsCompiledSQL): values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :param_1), name=(mytable.name || :name_1) " - "WHERE mytable.myid = hoho(:hoho_1) AND mytable.name = :param_2 || mytable.name || :param_3") + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :param_1), " + "name=(mytable.name || :name_1) " + "WHERE mytable.myid = hoho(:hoho_1) AND mytable.name = :param_2 || " + "mytable.name || :param_3") def test_correlated_update(self): # test against a straight text subquery - u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")}) - self.assert_compile(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") + u = update(table1, values = { + table1.c.name : + text("(select name from mytable where id=mytable.id)")}) + self.assert_compile(u, + "UPDATE mytable SET name=(select name from mytable " + "where id=mytable.id)") mt = table1.alias() - u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)}) - self.assert_compile(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") + u = update(table1, values = { + table1.c.name : + select([mt.c.name], mt.c.myid==table1.c.myid) + }) + self.assert_compile(u, + "UPDATE mytable SET name=(SELECT mytable_1.name FROM " + "mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)") # test against a regular constructed subquery s = select([table2], table2.c.otherid == table1.c.myid) u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) - self.assert_compile(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :name_1") + self.assert_compile(u, + "UPDATE mytable SET name=(SELECT myothertable.otherid, " + "myothertable.othername FROM myothertable WHERE " + "myothertable.otherid = mytable.myid) WHERE mytable.name = :name_1") # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name==s) - self.assert_compile(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = "\ - "(SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :otherid_1)") + self.assert_compile(u, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description WHERE mytable.name = " + "(SELECT myothertable.othername FROM myothertable " + "WHERE myothertable.otherid = :otherid_1)") # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid)