]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added verbose activity to profiling.function_call_count
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Apr 2008 22:36:40 +0000 (22:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Apr 2008 22:36:40 +0000 (22:36 +0000)
- simplified oracle non-ansi join generation, removed hooks from base compiler
- removed join() call from _label generation, fixed repeat label gen

lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/oracle.py
test/testlib/profiling.py

index 2763972649e4405e3f157b5da7acce73d23cc940..c601ad4e39469e701f0c7f8e03206a9582c06b7c 100644 (file)
@@ -570,7 +570,9 @@ class _OuterJoinColumn(sql.ClauseElement):
     __visit_name__ = 'outer_join_column'
     def __init__(self, column):
         self.column = column
-
+    def _get_from_objects(self, **kwargs):
+        return []
+    
 class OracleCompiler(compiler.DefaultCompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
@@ -609,36 +611,28 @@ class OracleCompiler(compiler.DefaultCompiler):
     def visit_join(self, join, **kwargs):
         if self.dialect.use_ansi:
             return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
-
-        (where, parentjoin) = self.__wheres.get(join, (None, None))
-
-        class VisitOn(visitors.ClauseVisitor):
-            def visit_binary(s, binary):
-                if binary.operator == sql_operators.eq:
-                    if binary.left.table is join.right:
-                        binary.left = _OuterJoinColumn(binary.left)
-                    elif binary.right.table is join.right:
-                        binary.right = _OuterJoinColumn(binary.right)
-
-        if join.isouter:
-            if where is not None:
-                self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
-            else:
-                self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
         else:
-            if where is not None:
-                self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(join.onclause, where), parentjoin)
+            return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+    
+    def _get_nonansi_join_whereclause(self, froms):
+        clauses = []
+        
+        def visit_join(join):
+            if join.isouter:
+                def visit_binary(binary):
+                    if binary.operator == sql_operators.eq:
+                        if binary.left.table is join.right:
+                            binary.left = _OuterJoinColumn(binary.left)
+                        elif binary.right.table is join.right:
+                            binary.right = _OuterJoinColumn(binary.right)
+                clauses.append(visitors.traverse(join.onclause, visit_binary=visit_binary, clone=True))
             else:
-                self.__wheres[join.left] = self.__wheres[join] = (join.onclause, join)
-
-        return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-
-    def get_whereclause(self, f):
-        if f in self.__wheres:
-            return self.__wheres[f][0]
-        else:
-            return None
-
+                clauses.append(join.onclause)
+        
+        for f in froms:
+            visitors.traverse(f, visit_join=visit_join)
+        return sql.and_(*clauses)
+        
     def visit_outer_join_column(self, vc):
         return self.process(vc.column) + "(+)"
 
@@ -662,27 +656,43 @@ class OracleCompiler(compiler.DefaultCompiler):
         so tries to wrap it in a subquery with ``row_number()`` criterion.
         """
 
-        if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
-            # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.process(select._order_by_clause)
-            if not orderby:
-                orderby = list(select.oid_column.proxies)[0]
-                orderby = self.process(orderby)
-
-            oldselect = select
-            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
-            select._oracle_visit = True
-
-            limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
-            if select._offset is not None:
-                limitselect.append_whereclause("ora_rn>%d" % select._offset)
-                if select._limit is not None:
-                    limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
-            else:
-                limitselect.append_whereclause("ora_rn<=%d" % select._limit)
-            return self.process(limitselect, iswrapper=True, **kwargs)
-        else:
-            return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
+        if not getattr(select, '_oracle_visit', None):
+            if not self.dialect.use_ansi:
+                if self.stack and 'from' in self.stack[-1]:
+                    existingfroms = self.stack[-1]['from']
+                else:
+                    existingfroms = None
+
+                froms = select._get_display_froms(existingfroms)
+                whereclause = self._get_nonansi_join_whereclause(froms)
+                if whereclause:
+                    select = select.where(whereclause)
+                    select._oracle_visit = True
+                
+            if select._limit is not None or select._offset is not None:
+                # to use ROW_NUMBER(), an ORDER BY is required.
+                orderby = self.process(select._order_by_clause)
+                if not orderby:
+                    orderby = list(select.oid_column.proxies)[0]
+                    orderby = self.process(orderby)
+
+                select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+                select._oracle_visit = True
+                
+                limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
+                limitselect._oracle_visit = True
+                limitselect._is_wrapper = True
+                
+                if select._offset is not None:
+                    limitselect.append_whereclause("ora_rn>%d" % select._offset)
+                    if select._limit is not None:
+                        limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
+                else:
+                    limitselect.append_whereclause("ora_rn<=%d" % select._limit)
+                select = limitselect
+        
+        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
+        return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
 
     def limit_clause(self, select):
         return ""
index 76e2ca2608eab5b9ebd388d67d098ac55624f8e9..800881e49a6c1aff6349b95042e6cb3c8efaf953 100644 (file)
@@ -193,16 +193,6 @@ class DefaultCompiler(engine.Compiled):
     def is_subquery(self, select):
         return self.stack and self.stack[-1].get('is_subquery')
 
-    def get_whereclause(self, obj):
-        """given a FROM clause, return an additional WHERE condition that should be
-        applied to a SELECT.
-
-        Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
-        constructs in non-ansi mode.
-        """
-
-        return None
-
     def construct_params(self, params=None):
         """return a dictionary of bind parameter keys and values"""
 
@@ -529,25 +519,17 @@ class DefaultCompiler(engine.Compiled):
         text += self.get_select_precolumns(select)
         text += collist
 
-        whereclause = select._whereclause
-
         from_strings = []
         for f in froms:
             from_strings.append(self.process(f, asfrom=True))
 
-            w = self.get_whereclause(f)
-            if w is not None:
-                if whereclause is not None:
-                    whereclause = sql.and_(w, whereclause)
-                else:
-                    whereclause = w
-
         if froms:
             text += " \nFROM "
             text += string.join(from_strings, ', ')
         else:
             text += self.default_from()
 
+        whereclause = select._whereclause
         if whereclause is not None:
             t = self.process(whereclause)
             if t:
index 758f75ebe7335cf3d2a2ea565b63fc65b9ae3377..d8a85d8bd85bb8612373778013b6f5a29e04fe6b 100644 (file)
@@ -2630,16 +2630,20 @@ class _ColumnClause(ColumnElement):
         # therefore no 'label' can be automatically generated
         if self.is_literal:
             return None
-        if self.__label is None:
-            if self.table is not None and self.table.named_with_column:
+        if not self.__label:
+            if self.table and self.table.named_with_column:
                 if getattr(self.table, 'schema', None):
-                    self.__label = "_".join([self.table.schema, self.table.name, self.name])
+                    self.__label = self.table.schema + "_" + self.table.name + "_" + self.name
                 else:
-                    self.__label = "_".join([self.table.name, self.name])
-                counter = 1
-                while self.__label in self.table.c:
-                    self.__label = self.__label + "_%d" % counter
-                    counter += 1
+                    self.__label = self.table.name + "_" + self.name
+                    
+                if self.__label in self.table.c:
+                    label = self.__label
+                    counter = 1
+                    while label in self.table.c:
+                        label = self.__label + "_" + str(counter)
+                        counter +=1
+                    self.__label = label
             else:
                 self.__label = self.name
         return self.__label
index 4b9745a2d29bca557d45777dd10e7f97cf8c6173..fda6f3549f63a4adbc2c4bf4f07bed46db69e131 100644 (file)
@@ -112,9 +112,10 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 )
         self.assert_compile(query,
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
-FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \
+FROM mytable, myothertable WHERE \
 (mytable.name = :mytable_name_1 OR mytable.myid = :mytable_myid_1 OR \
-myothertable.othername != :myothertable_othername_1 OR EXISTS (select yay from foo where boo = lar))",
+myothertable.othername != :myothertable_othername_1 OR EXISTS (select yay from foo where boo = lar)) \
+AND mytable.myid = myothertable.otherid(+)",
             dialect=oracle.OracleDialect(use_ansi = False))
 
         query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
@@ -124,6 +125,15 @@ myothertable.othername != :myothertable_othername_1 OR EXISTS (select yay from f
         query = table1.join(table2, table1.c.myid==table2.c.otherid).join(table3, table3.c.userid==table2.c.otherid)
         self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
 
+        query = table1.join(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
+        self.assert_compile(query.select().limit(10).offset(5), "SELECT myid, name, description, otherid, othername, userid, \
+otherstuff FROM (SELECT mytable.myid AS myid, mytable.name AS name, \
+mytable.description AS description, myothertable.otherid AS otherid, \
+myothertable.othername AS othername, thirdtable.userid AS userid, \
+thirdtable.otherstuff AS otherstuff, ROW_NUMBER() OVER (ORDER BY mytable.rowid) AS ora_rn \
+FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid(+) = myothertable.otherid) \
+WHERE ora_rn>5 AND ora_rn<=15", dialect=oracle.dialect(use_ansi=False))
+
     def test_alias_outer_join(self):
         address_types = table('address_types',
                     column('id'),
index 23e2e32145b3164348c372384943d60295077c66..54a96db47086a17529e5cf27e2ab24e1e820a56a 100644 (file)
@@ -142,6 +142,11 @@ def function_call_count(count=None, versions={}, variance=0.05):
                         "Function call count %s not within %s%% "
                         "of expected %s. (Python version %s)" % (
                         calls, (variance * 100), count, py_version))
+
+                if testlib.config.options.verbose:
+                    stats.sort_stats('calls', 'cumulative')
+                    stats.print_stats()
+
                 return result
             finally:
                 if os.path.exists(filename):