]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The argument to "ESCAPE" of a LIKE operator or similar
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 24 Jun 2010 16:19:15 +0000 (12:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 24 Jun 2010 16:19:15 +0000 (12:19 -0400)
is passed through render_literal_value(), which may
implement escaping of backslashes.  [ticket:1400]
- Postgresql render_literal_value() is overridden which escapes
backslashes, currently applies to the ESCAPE clause
of LIKE and similar expressions.
Ultimately this will have to detect the value of
"standard_conforming_strings" for full behavior.
[ticket:1400]
- MySQL render_literal_value() is overridden which escapes
backslashes, currently applies to the ESCAPE clause
of LIKE and similar expressions.   This behavior
is derived from detecting the value of
NO_BACKSLASH_ESCAPES.  [ticket:1400]

CHANGES
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/types.py
test/dialect/test_mysql.py
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index 3347fe397387e117e43e8bb3bcf224a118db1731..b138210a74a49ef0b642471f63eb518ad276a629 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -59,6 +59,10 @@ CHANGES
     is emitted once as per the warning filter settings, 
     and large string values don't pollute the output.
     [ticket:1822]
+
+  - The argument to "ESCAPE" of a LIKE operator or similar
+    is passed through render_literal_value(), which may 
+    implement escaping of backslashes.  [ticket:1400]
     
   - Fixed bug in Enum type which blew away native_enum
     flag when used with TypeDecorators or other adaption
@@ -78,11 +82,28 @@ CHANGES
     among others, fixes [ticket:1829] regarding declarative
     mixins
 
+- postgresql
+  - render_literal_value() is overridden which escapes
+    backslashes, currently applies to the ESCAPE clause
+    of LIKE and similar expressions.   
+    Ultimately this will have to detect the value of 
+    "standard_conforming_strings" for full behavior.  
+    [ticket:1400]
+  
 - mysql
   - MySQL dialect doesn't emit CAST() for MySQL version 
     detected < 4.0.2.  This allows the unicode
     check on connect to proceed. [ticket:1826]
 
+  - MySQL dialect now detects NO_BACKSLASH_ESCAPES sql
+    mode, in addition to ANSI_QUOTES.  
+    
+  - render_literal_value() is overridden which escapes
+    backslashes, currently applies to the ESCAPE clause
+    of LIKE and similar expressions.   This behavior
+    is derived from detecting the value of 
+    NO_BACKSLASH_ESCAPES.  [ticket:1400]
+    
 - oracle:
   - Fixed ora-8 compatibility flags such that they
     don't cache a stale value from before the first
index c4af013fc29d09a53501d1d976781e9f84b8202d..46e29694f421b3ba6400dc1939c21276e744d5a0 100644 (file)
@@ -1196,6 +1196,12 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
 
+    def render_literal_value(self, value, type_):
+        value = super(MySQLCompiler, self).render_literal_value(value, type_)
+        if self.dialect._backslash_escapes:
+            value = value.replace('\\', '\\\\')
+        return value
+        
     def get_select_precolumns(self, select):
         if isinstance(select._distinct, basestring):
             return select._distinct.upper() + " "
@@ -1639,6 +1645,12 @@ class MySQLDialect(default.DefaultDialect):
     ischema_names = ischema_names
     preparer = MySQLIdentifierPreparer
     
+    # default SQL compilation settings -
+    # these are modified upon initialize(), 
+    # i.e. first connect
+    _backslash_escapes = True
+    _server_ansiquotes = False
+    
     def __init__(self, use_ansiquotes=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
 
@@ -1760,7 +1772,7 @@ class MySQLDialect(default.DefaultDialect):
         self._connection_charset = self._detect_charset(connection)
         self._server_casing = self._detect_casing(connection)
         self._server_collations = self._detect_collations(connection)
-        self._server_ansiquotes = self._detect_ansiquotes(connection)
+        self._detect_ansiquotes(connection)
         if self._server_ansiquotes:
             # if ansiquotes == True, build a new IdentifierPreparer
             # with the new setting
@@ -2019,8 +2031,11 @@ class MySQLDialect(default.DefaultDialect):
                 mode_no = int(mode)
                 mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
 
-        return 'ANSI_QUOTES' in mode
-
+        self._server_ansiquotes = 'ANSI_QUOTES' in mode
+        
+        # as of MySQL 5.0.1
+        self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode
+        
     def _show_create_table(self, connection, table, charset=None,
                            full_name=None):
         """Run SHOW CREATE TABLE for a ``Table``."""
index 76d1122e8d66c267e0067e4a338de2c2aa66a74a..8275aa1e7d37a12111e9216d6cf638708ed62a1e 100644 (file)
@@ -324,12 +324,23 @@ class PGCompiler(compiler.SQLCompiler):
     def visit_ilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+                + (escape and 
+                        (' ESCAPE ' + self.render_literal_value(escape, None))
+                        or '')
 
     def visit_notilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+                + (escape and 
+                        (' ESCAPE ' + self.render_literal_value(escape, None))
+                        or '')
+
+    def render_literal_value(self, value, type_):
+        value = super(PGCompiler, self).render_literal_value(value, type_)
+        # TODO: need to inspect "standard_conforming_strings"
+        if self.dialect._backslash_escapes:
+            value = value.replace('\\', '\\\\')
+        return value
 
     def visit_sequence(self, seq):
         if seq.optional:
@@ -625,6 +636,9 @@ class PGDialect(default.DefaultDialect):
     inspector = PGInspector
     isolation_level = None
 
+    # TODO: need to inspect "standard_conforming_strings"
+    _backslash_escapes = True
+
     def __init__(self, isolation_level=None, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
         self.isolation_level = isolation_level
index b4992eec330685844448d7711fb1ef34c8e4b005..c54931b87b0046877de24795ffb31c2bea39f2be 100644 (file)
@@ -494,28 +494,36 @@ class SQLCompiler(engine.Compiled):
         return '%s LIKE %s' % (
                                     self.process(binary.left, **kw), 
                                     self.process(binary.right, **kw)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+            + (escape and 
+                    (' ESCAPE ' + self.render_literal_value(escape, None))
+                    or '')
 
     def visit_notlike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT LIKE %s' % (
                                     self.process(binary.left, **kw), 
                                     self.process(binary.right, **kw)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+            + (escape and 
+                    (' ESCAPE ' + self.render_literal_value(escape, None))
+                    or '')
         
     def visit_ilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) LIKE lower(%s)' % (
                                             self.process(binary.left, **kw), 
                                             self.process(binary.right, **kw)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+            + (escape and 
+                    (' ESCAPE ' + self.render_literal_value(escape, None))
+                    or '')
     
     def visit_notilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) NOT LIKE lower(%s)' % (
                                             self.process(binary.left, **kw), 
                                             self.process(binary.right, **kw)) \
-            + (escape and ' ESCAPE \'%s\'' % escape or '')
+            + (escape and 
+                    (' ESCAPE ' + self.render_literal_value(escape, None))
+                    or '')
         
     def _operator_dispatch(self, operator, element, fn, **kw):
         if util.callable(operator):
index 84bd85ff2ddb9b001b71a26997d3a07338334619..d8a176f1fdf7eceb20e9b115874512d093e7faf4 100644 (file)
@@ -1796,6 +1796,7 @@ class BOOLEAN(Boolean):
 
 NULLTYPE = NullType()
 BOOLEANTYPE = Boolean()
+STRINGTYPE = String()
 
 # using VARCHAR/NCHAR so that we dont get the genericized "String"
 # type which usually resolves to TEXT/CLOB
index f41f06209597a749e8551c2effed601aa341cd18..964428cdf65cfbc3c989330e1120947a9592f8a3 100644 (file)
@@ -1017,7 +1017,21 @@ class SQLTest(TestBase, AssertsCompiledSQL):
         eq_(
             gen(True, ['high_priority', sql.text('sql_cache')]),
             'SELECT high_priority sql_cache DISTINCT q')
+        
+    def test_backslash_escaping(self):
+        self.assert_compile(
+            sql.column('foo').like('bar', escape='\\'),
+            "foo LIKE %s ESCAPE '\\\\'"
+        )
 
+        dialect = mysql.dialect()
+        dialect._backslash_escapes=False
+        self.assert_compile(
+            sql.column('foo').like('bar', escape='\\'),
+            "foo LIKE %s ESCAPE '\\'",
+            dialect=dialect
+        )
+        
     def test_limit(self):
         t = sql.table('t', sql.column('col1'), sql.column('col2'))
 
@@ -1221,7 +1235,46 @@ class SQLTest(TestBase, AssertsCompiledSQL):
             ")ENGINE=InnoDB"
         )
 
-
+class SQLModeDetectionTest(TestBase):
+    __only_on__ = 'mysql'
+    
+    def _options(self, modes):
+        class SetOptions(object):
+            def first_connect(self, con, record):
+                self.connect(con, record)
+            def connect(self, con, record):
+                cursor = con.cursor()
+                cursor.execute("set sql_mode='%s'" % (",".join(modes)))
+        return engines.testing_engine(options={"listeners":[SetOptions()]})
+        
+    def test_backslash_escapes(self):
+        engine = self._options(['NO_BACKSLASH_ESCAPES'])
+        c = engine.connect()
+        assert not engine.dialect._backslash_escapes
+        c.close()
+        engine.dispose()
+
+        engine = self._options([])
+        c = engine.connect()
+        assert engine.dialect._backslash_escapes
+        c.close()
+        engine.dispose()
+
+    def test_ansi_quotes(self):
+        engine = self._options(['ANSI_QUOTES'])
+        c = engine.connect()
+        assert engine.dialect._server_ansiquotes
+        c.close()
+        engine.dispose()
+
+    def test_combination(self):
+        engine = self._options(['ANSI_QUOTES,NO_BACKSLASH_ESCAPES'])
+        c = engine.connect()
+        assert engine.dialect._server_ansiquotes
+        assert not engine.dialect._backslash_escapes
+        c.close()
+        engine.dispose()
+        
 class RawReflectionTest(TestBase):
     def setup(self):
         dialect = mysql.dialect()
index 2b51d68a26c75a0f83515c88d2212e19ebbaa586..e8f9d118b3d87f50c8fed90f09b42a9c89c79b01 100644 (file)
@@ -376,13 +376,21 @@ class QueryTest(TestBase):
         )
 
         for expr, result in (
-            (select([users.c.user_id]).where(users.c.user_name.startswith('apple')), [(1,)]),
-            (select([users.c.user_id]).where(users.c.user_name.contains('i % t')), [(5,)]),
-            (select([users.c.user_id]).where(users.c.user_name.endswith('anas')), [(3,)]),
+            (select([users.c.user_id]).\
+                    where(users.c.user_name.startswith('apple')), [(1,)]),
+            (select([users.c.user_id]).\
+                    where(users.c.user_name.contains('i % t')), [(5,)]),
+            (select([users.c.user_id]).\
+                    where(
+                        users.c.user_name.endswith('anas')
+                    ), [(3,)]),
+            (select([users.c.user_id]).\
+                    where(
+                        users.c.user_name.contains('i % t', escape='\\')
+                    ), [(5,)]),
         ):
             eq_(expr.execute().fetchall(), result)
     
-
     @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text")
     @testing.fails_on("oracle", "neither % nor %% are accepted")
     @testing.fails_on("+pg8000", "can't interpret result column from '%%'")