]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Try to only convert :bind params and leave colons in text literals alone
authorJason Kirtland <jek@discorporate.us>
Tue, 10 Jul 2007 21:53:03 +0000 (21:53 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 10 Jul 2007 21:53:03 +0000 (21:53 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py
test/sql/select.py

index d2e779a7e65062fe4233cf3573e4f6b09b3b3319..188063a82da1346fbe324137a14d4d8f57bee14e 100644 (file)
@@ -40,6 +40,9 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array',
 LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
 ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
 
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
+
 class ANSIDialect(default.DefaultDialect):
     def __init__(self, cache_identifiers=True, **kwargs):
         super(ANSIDialect,self).__init__(**kwargs)
@@ -177,23 +180,29 @@ class ANSICompiler(engine.Compiled):
         # this re will search for params like :param
         # it has a negative lookbehind for an extra ':' so that it doesnt match
         # postgres '::text' tokens
-        match = re.compile(r'(?<!:):([\w_]+)', re.UNICODE)
+        text = self.strings[self.statement]
+        if ':' not in text:
+            return
+        
         if self.paramstyle=='pyformat':
-            self.strings[self.statement] = match.sub(lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+            text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
         elif self.positional:
-            params = match.finditer(self.strings[self.statement])
+            params = BIND_PARAMS.finditer(text)
             for p in params:
                 self.positiontup.append(p.group(1))
             if self.paramstyle=='qmark':
-                self.strings[self.statement] = match.sub('?', self.strings[self.statement])
+                text = BIND_PARAMS.sub('?', text)
             elif self.paramstyle=='format':
-                self.strings[self.statement] = match.sub('%s', self.strings[self.statement])
+                text = BIND_PARAMS.sub('%s', text)
             elif self.paramstyle=='numeric':
                 i = [0]
                 def getnum(x):
                     i[0] += 1
                     return str(i[0])
-                self.strings[self.statement] = match.sub(getnum, self.strings[self.statement])
+                text = BIND_PARAMS.sub(getnum, text)
+        # un-escape any \:params
+        text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
+        self.strings[self.statement] = text
 
     def get_from_text(self, obj):
         return self.froms.get(obj, None)
index a4808623e0f623531b61923e3e77086f8933ec37..0abb28a7e12142963f2ed60569847be66d56b901 100644 (file)
@@ -71,6 +71,7 @@ PRECEDENCE = {
     '_smallest': -1000,
     '_largest': 1000
 }
+BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
 
 def desc(column):
     """Return a descending ``ORDER BY`` clause element.
@@ -1765,7 +1766,7 @@ class _TextClause(ClauseElement):
         
         # scan the string and search for bind parameter names, add them
         # to the list of bindparams
-        self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text)
+        self.text = BIND_PARAMS.sub(repl, text)
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b
index 7d9a2a8ca3469a68299945be897f06a8cabc3794..8c1b9da7d659b7eb6e5c090bfd97b4c3c4078b1d 100644 (file)
@@ -479,7 +479,14 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
                 checkparams={'bar':4, 'whee': 7},
                 params={'bar':4, 'whee': 7, 'hoho':10},
         )
-        
+
+        self.runtest(
+            text("select * from foo where clock='05:06:07'"), 
+                "select * from foo where clock='05:06:07'", 
+                checkparams={},
+                params={},
+        )
+
         dialect = postgres.dialect()
         self.runtest(
             text("select * from foo where lala=:bar and hoho=:whee"), 
@@ -488,6 +495,13 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
                 params={'bar':4, 'whee': 7, 'hoho':10},
                 dialect=dialect
         )
+        self.runtest(
+            text("select * from foo where clock='05:06:07' and mork='\:mindy'"),
+            "select * from foo where clock='05:06:07' and mork=':mindy'",
+            checkparams={},
+            params={},
+            dialect=dialect
+        )
 
         dialect = sqlite.dialect()
         self.runtest(