]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- removed regular expression step from most statement compilations.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Oct 2007 17:41:30 +0000 (17:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Oct 2007 17:41:30 +0000 (17:41 +0000)
  also fixes [ticket:833]
- inlining on PG with_returning() call
- extra options added for profiling

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/sql/compiler.py
test/dialect/sqlite.py
test/testlib/profiling.py

diff --git a/CHANGES b/CHANGES
index cbc67819f9626658a140acc023c7d3a1c3058a80..5e4c52b020be81d46b2f6f9e4b0b045bfc533aa4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,6 +4,8 @@ CHANGES
 
 0.4.1
 -----
+- removed regular expression step from most statement compilations.
+  also fixes [ticket:833]
 
 - Added test coverage for unknown type reflection, fixed sqlite/mysql
   handling of type reflection for unknown types.
index 018074b67755a3361d7a6d6768f457266c87831f..6449bdc1447e2b0b9e48433a8d5280a0d3928d49 100644 (file)
@@ -648,27 +648,31 @@ class PGCompiler(compiler.DefaultCompiler):
             return super(PGCompiler, self).for_update_clause(select)
 
     def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs.get('postgres_returning', None)
-        if returning_cols:
-            def flatten_columnlist(collist):
-                for c in collist:
-                    if isinstance(c, expression.Selectable):
-                        for co in c.columns:
-                            yield co
-                    else:
-                        yield c
-            columns = [self.process(c) for c in flatten_columnlist(returning_cols)]
-            text += ' RETURNING ' + string.join(columns, ', ')
-        
+        returning_cols = stmt.kwargs['postgres_returning']
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+        columns = [self.process(c) for c in flatten_columnlist(returning_cols)]
+        text += ' RETURNING ' + string.join(columns, ', ')
         return text
 
     def visit_update(self, update_stmt):
         text = super(PGCompiler, self).visit_update(update_stmt)
-        return self._append_returning(text, update_stmt)
+        if 'postgres_returning' in update_stmt.kwargs:
+            return self._append_returning(text, update_stmt)
+        else:
+            return text
 
     def visit_insert(self, insert_stmt):
         text = super(PGCompiler, self).visit_insert(insert_stmt)
-        return self._append_returning(text, insert_stmt)
+        if 'postgres_returning' in insert_stmt.kwargs:
+            return self._append_returning(text, insert_stmt)
+        else:
+            return text
 
 class PGSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
index f2627eb85f0defd010a956bb9d10a79aed3cf0ae..e662f8e99d8f0b03a196fbb69d331c48fddabc80 100644 (file)
@@ -43,6 +43,15 @@ BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
 BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)
 ANONYMOUS_LABEL = re.compile(r'{ANON (-?\d+) (.*)}')
 
+BIND_TEMPLATES = {
+    'pyformat':"%%(%(name)s)s",
+    'qmark':"?",
+    'format':"%%s",
+    'numeric':"%(position)s",
+    'named':":%(name)s"
+}
+    
+
 OPERATORS =  {
     operators.and_ : 'AND',
     operators.or_ : 'OR',
@@ -132,15 +141,14 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         # for aliases
         self.generated_ids = {}
         
-        # default formatting style for bind parameters
-        self.bindtemplate = ":%s"
-
         # paramstyle from the dialect (comes from DB-API)
         self.paramstyle = self.dialect.paramstyle
 
         # true if the paramstyle is positional
         self.positional = self.dialect.positional
 
+        self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
+        
         # a list of the compiled's bind parameter names, used to help
         # formulate a positional argument list
         self.positiontup = []
@@ -148,38 +156,8 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         # an IdentifierPreparer that formats the quoting of identifiers
         self.preparer = self.dialect.identifier_preparer
         
-        
-    def after_compile(self):
-        # this re will search for params like :param
-        # it has a negative lookbehind for an extra ':' so that it doesnt match
-        # postgres '::text' tokens
-        text = self.string
-        if ':' not in text:
-            return
-        
-        if self.paramstyle=='pyformat':
-            text = BIND_PARAMS.sub(lambda m:'%(' + m.group(1) +')s', text)
-        elif self.positional:
-            params = BIND_PARAMS.finditer(text)
-            for p in params:
-                self.positiontup.append(p.group(1))
-            if self.paramstyle=='qmark':
-                text = BIND_PARAMS.sub('?', text)
-            elif self.paramstyle=='format':
-                text = BIND_PARAMS.sub('%s', text)
-            elif self.paramstyle=='numeric':
-                i = [0]
-                def getnum(x):
-                    i[0] += 1
-                    return str(i[0])
-                text = BIND_PARAMS.sub(getnum, text)
-        # un-escape any \:params
-        text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
-        self.string = text
-
     def compile(self):
         self.string = self.process(self.statement)
-        self.after_compile()
     
     def process(self, obj, stack=None, **kwargs):
         if stack:
@@ -291,11 +269,20 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         return typeclause.type.dialect_impl(self.dialect).get_col_spec()
 
     def visit_textclause(self, textclause, **kwargs):
-        for bind in textclause.bindparams.values():
-            self.process(bind)
         if textclause.typemap is not None:
             self.typemap.update(textclause.typemap)
-        return textclause.text
+            
+        def do_bindparam(m):
+            name = m.group(1)
+            if name in textclause.bindparams:
+                return self.process(textclause.bindparams[name])
+            else:
+                return self.bindparam_string(name)
+
+        # un-escape any \:params
+        return BIND_PARAMS_ESC.sub(lambda m: m.group(1), 
+            BIND_PARAMS.sub(do_bindparam, textclause.text)
+        )
 
     def visit_null(self, null, **kwargs):
         return 'NULL'
@@ -437,7 +424,10 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         return ANONYMOUS_LABEL.sub(self._process_anon, name)
             
     def bindparam_string(self, name):
-        return self.bindtemplate % name
+        if self.positional:
+            self.positiontup.append(name)
+            
+        return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
 
     def visit_alias(self, alias, asfrom=False, **kwargs):
         if asfrom:
index 285921588e948de8fb6279bdc474da12f32c3305..f3eac38f92dd12763295ba588bd340763f3e2219 100644 (file)
@@ -157,6 +157,21 @@ class InsertTest(AssertMixin):
                   Column('x', Integer),
                   Column('y', Integer)))
 
+    @testing.supported('sqlite')
+    def test_inserts_with_spaces(self):
+        tbl = Table('tbl', MetaData('sqlite:///'),
+                  Column('with space', Integer),
+                  Column('without', Integer))
+        tbl.create()
+        try:
+            tbl.insert().execute({'without':123})
+            assert list(tbl.select().execute()) == [(None, 123)]
 
+            tbl.insert().execute({'with space':456})
+            assert list(tbl.select().execute()) == [(None, 123), (456, None)]
+
+        finally:
+            tbl.drop()
+          
 if __name__ == "__main__":
     testbase.main()
index f2f75f66e86a9cdec05711a2137a731c71448484..947bf962e0db61638e7586c7a5b486ca2cde3408 100644 (file)
@@ -13,7 +13,7 @@ profile_config = { 'targets': set(),
                    'sort': ('time', 'calls'),
                    'limit': None }
 
-def profiled(target, **target_opts):
+def profiled(target=None, **target_opts):
     """Optional function profiling.
 
     @profiled('label')
@@ -28,7 +28,9 @@ def profiled(target, **target_opts):
     import time, hotshot, hotshot.stats
 
     # manual or automatic namespacing by module would remove conflict issues
-    if target in all_targets:
+    if target is None:
+        target = 'anonymous_target'
+    elif target in all_targets:
         print "Warning: redefining profile target '%s'" % target
     all_targets.add(target)
 
@@ -70,6 +72,8 @@ def profiled(target, **target_opts):
 
             assert_range = target_opts.get('call_range')
             if assert_range:
+                if isinstance(assert_range, dict):
+                    assert_range = assert_range.get(testlib.config.db, 'default')
                 stats = hotshot.stats.load(filename)
                 assert stats.total_calls >= assert_range[0] and stats.total_calls <= assert_range[1], stats.total_calls