]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- took out method calls for oid_column
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Sep 2007 18:07:16 +0000 (18:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Sep 2007 18:07:16 +0000 (18:07 +0000)
- reduced complexity of parameter handling during execution; __distill_params does all
parameter munging, executioncontext.parameters always holds a list of parameter structures
(lists, tuples, or dicts).

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/profiling/zoomark.py
test/testlib/testing.py

index 32bf7b780fe5adb93edc88c23314eee3624eb3ba..cc2ea6d5dea7adf77e3c79978ee34452c49bc38c 100644 (file)
@@ -790,42 +790,43 @@ class Connection(Connectable):
         return context.result()
 
     def __distill_params(self, multiparams, params):
-        if multiparams is None or len(multiparams) == 0:
-            parameters = params or None
-        elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
-            parameters = multiparams[0]
-        else:
-            parameters = list(multiparams)
-        return parameters
-
-    def __distill_params_and_keys(self, multiparams, params):
+        """given arguments from the calling form *multiparams, **params, return a list
+        of bind parameter structures, usually a list of dictionaries.  
+        
+        in the case of 'raw' execution which accepts positional parameters, 
+        it may be a list of tuples or lists."""
+        
         if multiparams is None or len(multiparams) == 0:
             if params:
-                parameters = params
-                keys = params.keys()
+                return [params]
             else:
-                parameters = None
-                keys = []
-            executemany = False
-        elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)):
-            parameters = multiparams[0]
-            if isinstance(parameters, dict):
-                keys = parameters.keys()
+                return [{}]
+        elif len(multiparams) == 1:
+            if isinstance(multiparams[0], (list, tuple)):
+                if isinstance(multiparams[0][0], (list, tuple, dict)):
+                    return multiparams[0]
+                else:
+                    return [multiparams[0]]
+            elif isinstance(multiparams[0], dict):
+                return [multiparams[0]]
             else:
-                keys = parameters[0].keys()
-            executemany = False
+                return [[multiparams[0]]]
         else:
-            parameters = list(multiparams)
-            keys = parameters[0].keys()
-            executemany = True
-        return (parameters, keys, executemany)
+            if isinstance(multiparams[0], (list, tuple, dict)):
+                return multiparams
+            else:
+                return [multiparams]
 
     def _execute_function(self, func, multiparams, params):
         return self._execute_clauseelement(func.select(), multiparams, params)
 
     def _execute_clauseelement(self, elem, multiparams=None, params=None):
-        (params, keys, executemany) = self.__distill_params_and_keys(multiparams, params)
-        return self._execute_compiled(elem.compile(dialect=self.dialect, column_keys=keys, inline=executemany), distilled_params=params)
+        params = self.__distill_params(multiparams, params)
+        if params:
+            keys = params[0].keys()
+        else:
+            keys = None
+        return self._execute_compiled(elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), distilled_params=params)
 
     def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None):
         """Execute a sql.Compiled object."""
@@ -845,17 +846,10 @@ class Connection(Connectable):
         return self.__engine.dialect.create_execution_context(connection=self, **kwargs)
 
     def __execute_raw(self, context):
-        if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and isinstance(context.parameters[0], (list, tuple, dict)):
+        if context.executemany:
             self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
         else:
-            if context.parameters is None:
-                if context.dialect.positional:
-                    parameters = ()
-                else:
-                    parameters = {}
-            else:
-                parameters = context.parameters
-            self._cursor_execute(context.cursor, context.statement, parameters, context=context)
+            self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
         self._autocommit(context)
 
     def _cursor_execute(self, cursor, statement, parameters, context=None):
@@ -1123,6 +1117,10 @@ class Engine(Connectable):
     def scalar(self, statement, *multiparams, **params):
         return self.execute(statement, *multiparams, **params).scalar()
 
+    def _execute_clauseelement(self, elem, multiparams=None, params=None):
+        connection = self.contextual_connect(close_with_result=True)
+        return connection._execute_clauseelement(elem, multiparams, params)
+
     def _execute_compiled(self, compiled, multiparams, params):
         connection = self.contextual_connect(close_with_result=True)
         return connection._execute_compiled(compiled, multiparams, params)
index 578b19d166a59719c41c5eb42e5437eb8a4cb4dc..07f07d0be6b0a51e086a294760402dbdcb1312cf 100644 (file)
@@ -141,28 +141,22 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.statement = unicode(compiled)
             self.isinsert = compiled.isinsert
             self.isupdate = compiled.isupdate
-            if parameters is None:
-                self.compiled_parameters = compiled.construct_params()
-                self.executemany = False
-            elif not isinstance(parameters, (list, tuple)):
-                self.compiled_parameters = compiled.construct_params(parameters)
+            if not parameters:
+                self.compiled_parameters = [compiled.construct_params()]
                 self.executemany = False
             else:
                 self.compiled_parameters = [compiled.construct_params(m) for m in parameters]
-                if len(self.compiled_parameters) == 1:
-                    self.compiled_parameters = self.compiled_parameters[0]
-                    self.executemany = False
-                else:
-                    self.executemany = True
+                self.executemany = len(parameters) > 1
 
         elif statement is not None:
             self.typemap = self.column_labels = None
             self.parameters = self.__encode_param_keys(parameters)
+            self.executemany = len(parameters) > 1
             self.statement = statement
             self.isinsert = self.isupdate = False
         else:
             self.statement = None
-            self.isinsert = self.isupdate = False
+            self.isinsert = self.isupdate = self.executemany = False
             
         if self.statement is not None and not dialect.supports_unicode_statements:
             self.statement = self.statement.encode(self.dialect.encoding)
@@ -174,9 +168,17 @@ class DefaultExecutionContext(base.ExecutionContext):
     root_connection = property(lambda s:s._connection)
     
     def __encode_param_keys(self, params):
-        """apply string encoding to the keys of dictionary-based bind parameters"""
+        """apply string encoding to the keys of dictionary-based bind parameters.
+        
+        This is only used executing textual, non-compiled SQL expressions."""
+        
         if self.dialect.positional or self.dialect.supports_unicode_statements:
-            return params
+            if params:
+                return params
+            elif self.dialect.positional:
+                return [()]
+            else:
+                return [{}]
         else:
             def proc(d):
                 # sigh, sometimes we get positional arguments with a dialect
@@ -184,32 +186,15 @@ class DefaultExecutionContext(base.ExecutionContext):
                 if not isinstance(d, dict):
                     return d
                 return dict([(k.encode(self.dialect.encoding), d[k]) for k in d])
-            if isinstance(params, list):
-                return [proc(d) for d in params]
-            else:
-                return proc(params)
+            return [proc(d) for d in params] or [{}]
 
     def __convert_compiled_params(self, parameters):
-        encode = not self.dialect.supports_unicode_statements
-        # the bind params are a CompiledParams object.  but all the
-        # DB-API's hate that object (or similar).  so convert it to a
-        # clean dictionary/list/tuple of dictionary/tuple of list
-        if parameters is not None:
-            if self.executemany:
-                processors = parameters[0].get_processors()
-            else:
-                processors = parameters.get_processors()
-
-            if self.dialect.positional:
-                if self.executemany:
-                    parameters = [p.get_raw_list(processors) for p in parameters]
-                else:
-                    parameters = parameters.get_raw_list(processors)
-            else:
-                if self.executemany:
-                    parameters = [p.get_raw_dict(processors, encode_keys=encode) for p in parameters]
-                else:
-                    parameters = parameters.get_raw_dict(processors, encode_keys=encode)
+        processors = parameters[0].get_processors()
+        if self.dialect.positional:
+            parameters = [p.get_raw_list(processors) for p in parameters]
+        else:
+            encode = not self.dialect.supports_unicode_statements
+            parameters = [p.get_raw_dict(processors, encode_keys=encode) for p in parameters]
         return parameters
                 
     def is_select(self):
@@ -275,10 +260,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         from the bind parameter's ``TypeEngine`` objects.
         """
 
-        if isinstance(self.compiled_parameters, list):
-            plist = self.compiled_parameters
-        else:
-            plist = [self.compiled_parameters]
+        plist = self.compiled_parameters
         if self.dialect.positional:
             inputsizes = []
             for params in plist[0:1]:
@@ -319,6 +301,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     self.compiled_parameters = params
                     
             else:
+                compiled_parameters = self.compiled_parameters[0]
                 drunner = self.dialect.defaultrunner(self)
                 if self.isinsert:
                     self._last_inserted_ids = []
@@ -328,18 +311,18 @@ class DefaultExecutionContext(base.ExecutionContext):
                     else:
                         val = drunner.get_column_onupdate(c)
                     if val is not None:
-                        self.compiled_parameters.set_value(c.key, val)
+                        compiled_parameters.set_value(c.key, val)
 
                 if self.isinsert:
-                    processors = self.compiled_parameters.get_processors()
+                    processors = compiled_parameters.get_processors()
                     for c in self.compiled.statement.table.primary_key:
-                        if c.key in self.compiled_parameters:
-                            self._last_inserted_ids.append(self.compiled_parameters.get_processed(c.key, processors))
+                        if c.key in compiled_parameters:
+                            self._last_inserted_ids.append(compiled_parameters.get_processed(c.key, processors))
                         else:
                             self._last_inserted_ids.append(None)
                             
                 self._postfetch_cols = self.compiled.postfetch
                 if self.isinsert:
-                    self._last_inserted_params = self.compiled_parameters
+                    self._last_inserted_params = compiled_parameters
                 else:
-                    self._last_updated_params = self.compiled_parameters
+                    self._last_updated_params = compiled_parameters
index eb416803a46c5147093d1b02d8985c53c340c523..ded493642b37f5d90561f369af22a875d460d933 100644 (file)
@@ -272,15 +272,15 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
             if column.table.oid_column is column:
                 n = self.dialect.oid_column_name(column)
                 if n is not None:
-                    return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
+                    return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n)
                 elif len(column.table.primary_key) != 0:
                     pk = list(column.table.primary_key)[0]
                     pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
-                    return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
+                    return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
                 else:
                     return None
             else:
-                return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
+                return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
 
 
     def visit_fromclause(self, fromclause, **kwargs):
index f88c418ebf3b57987acee7849ff49f3c1a83fe60..f3c17236d87c77f9887fbf6757e136b126c6aad0 100644 (file)
@@ -963,17 +963,11 @@ class ClauseElement(object):
 
     def execute(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``."""
-        
-        if len(multiparams) == 0:
-            keys = params.keys()
-        elif isinstance(multiparams[0], dict):
-            keys = multiparams[0].keys()
-        elif isinstance(multiparams[0], (list, tuple)):
-            keys = multiparams[0][0].keys()
-        else:
-            keys = None
 
-        return self.compile(bind=self.bind, column_keys=keys, inline=(len(multiparams) > 1)).execute(*multiparams, **params)
+        e = self.bind
+        if e is None:
+            raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
+        return e._execute_clauseelement(self, multiparams, params)
 
     def scalar(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``, returning the result's scalar representation."""
@@ -1516,7 +1510,8 @@ class FromClause(Selectable):
 
     def __init__(self, name=None):
         self.name = name
-
+        self.oid_column = None
+        
     def _get_from_objects(self, **modifiers):
         # this could also be [self], at the moment it doesnt matter to the Select object
         return []
@@ -1547,16 +1542,6 @@ class FromClause(Selectable):
 
         return False
 
-    def _locate_oid_column(self):
-        """Subclasses should override this to return an appropriate OID column."""
-
-        return None
-
-    def _get_oid_column(self):
-        if not hasattr(self, '_oid_column'):
-            self._oid_column = self._locate_oid_column()
-        return self._oid_column
-
     def _get_all_embedded_columns(self):
         ret = []
         class FindCols(visitors.ClauseVisitor):
@@ -1656,7 +1641,6 @@ class FromClause(Selectable):
         """A dictionary mapping an original Table-bound 
         column to a proxied column in this FromClause.
         """)
-    oid_column = property(_get_oid_column)
 
     def _export_columns(self, columns=None):
         """Initialize column collections.
@@ -2012,6 +1996,7 @@ class _Function(_CalculatedClause, FromClause):
 
     def __init__(self, name, *clauses, **kwargs):
         self.packagenames = kwargs.get('packagenames', None) or []
+        self.oid_column = None
         kwargs['operator'] = operators.comma_op
         _CalculatedClause.__init__(self, name, **kwargs)
         for c in clauses:
@@ -2190,6 +2175,7 @@ class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
         self.left = _selectable(left)
         self.right = _selectable(right).self_group()
+        self.oid_column = self.left.oid_column
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
         else:
@@ -2238,8 +2224,6 @@ class Join(FromClause):
     def self_group(self, against=None):
         return _FromGrouping(self)
 
-    def _locate_oid_column(self):
-        return self.left.oid_column
 
     def _exportable_columns(self):
         return [c for c in self.left.columns] + [c for c in self.right.columns]
@@ -2393,6 +2377,10 @@ class Alias(FromClause):
             alias = '{ANON %d %s}' % (id(self), alias or 'anon')
         self.name = alias
         self.encodedname = alias.encode('ascii', 'backslashreplace')
+        if self.selectable.oid_column is not None:
+            self.oid_column = self.selectable.oid_column._make_proxy(self)
+        else:
+            self.oid_column = None
 
     def is_derived_from(self, fromclause):
         x = self.selectable
@@ -2411,12 +2399,6 @@ class Alias(FromClause):
     def _table_iterator(self):
         return self.original._table_iterator()
 
-    def _locate_oid_column(self):
-        if self.selectable.oid_column is not None:
-            return self.selectable.oid_column._make_proxy(self)
-        else:
-            return None
-
     def named_with_column(self):
         return True
 
@@ -2655,7 +2637,7 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
         self.encodedname = self.name.encode('ascii', 'backslashreplace')
-        self._oid_column = _ColumnClause('oid', self, _is_oid=True)
+        self.oid_column = _ColumnClause('oid', self, _is_oid=True)
         self._export_columns(columns)
 
     def _clone(self):
@@ -2669,9 +2651,6 @@ class TableClause(FromClause):
         self._columns[c.name] = c
         c.table = self
 
-    def _locate_oid_column(self):
-        return self._oid_column
-
     def _proxy_column(self, c):
         self.append_column(c)
         return c
@@ -2844,7 +2823,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
                 self.selects.append(s)
 
         self._col_map = {}
-
+        self.oid_column = self.selects[0].oid_column
+        
         _SelectBaseMixin.__init__(self, **kwargs)
 
     name = property(lambda s:s.keyword + " statement")
@@ -2852,9 +2832,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
     def self_group(self, against=None):
         return _FromGrouping(self)
 
-    def _locate_oid_column(self):
-        return self.selects[0].oid_column
-
     def _exportable_columns(self):
         for s in self.selects:
             for c in s.c:
@@ -3165,6 +3142,7 @@ class Select(_SelectBaseMixin, FromClause):
                 return oid
         else:
             return None
+    oid_column = property(_locate_oid_column)
 
     def union(self, other, **kwargs):
         return union(self, other, **kwargs)
index ce502f03a56a65783cd05ab03324e0843edb4904..7626547fbf6845e49b54c3d361a487b7975b3049 100644 (file)
@@ -44,7 +44,7 @@ class ZooMarkTest(testing.AssertMixin):
         metadata.create_all()
         
     @testing.supported('postgres')
-    @profiling.profiled('populate', call_range=(4420, 4460), always=True)        
+    @profiling.profiled('populate', call_range=(4410, 4420), always=True)        
     def test_1a_populate(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -120,7 +120,7 @@ class ZooMarkTest(testing.AssertMixin):
             tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
     
     @testing.supported('postgres')
-    @profiling.profiled('properties', call_range=(3590, 3610), always=True)        
+    @profiling.profiled('properties', call_range=(3440, 3450), always=True)        
     def test_3_properties(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -143,7 +143,7 @@ class ZooMarkTest(testing.AssertMixin):
             ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
     
     @testing.supported('postgres')
-    @profiling.profiled('expressions', call_range=(13790, 13800), always=True)        
+    @profiling.profiled('expressions', call_range=(13260, 13270), always=True)        
     def test_4_expressions(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -197,7 +197,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
     
     @testing.supported('postgres')
-    @profiling.profiled('aggregates', call_range=(1290, 1300), always=True)        
+    @profiling.profiled('aggregates', call_range=(1270, 1280), always=True)        
     def test_5_aggregates(self):
         Animal = metadata.tables['Animal']
         Zoo = metadata.tables['Zoo']
@@ -239,7 +239,7 @@ class ZooMarkTest(testing.AssertMixin):
             legs.sort()
     
     @testing.supported('postgres')
-    @profiling.profiled('editing', call_range=(1430, 1450), always=True)        
+    @profiling.profiled('editing', call_range=(1390, 1400), always=True)        
     def test_6_editing(self):
         Zoo = metadata.tables['Zoo']
         
@@ -268,7 +268,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert SDZ['Founded'] == datetime.date(1935, 9, 13)
     
     @testing.supported('postgres')
-    @profiling.profiled('multiview', call_range=(3230, 3240), always=True)        
+    @profiling.profiled('multiview', call_range=(3160, 3170), always=True)        
     def test_7_multiview(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
index 26873b25f611ae5d2f0096ce068b9c0879b7111f..2052f9e9751fa0f681401d790933e5386966b3a6 100644 (file)
@@ -160,14 +160,11 @@ class ExecutionContextWrapper(object):
             (query, params) = item
             if callable(params):
                 params = params(ctx)
-            if params is not None and isinstance(params, list) and len(params) == 1:
-                params = params[0]
+            if params is not None and not isinstance(params, list):
+                params = [params]
             
             from sqlalchemy.sql.util import ClauseParameters
-            if isinstance(ctx.compiled_parameters, ClauseParameters):
-                parameters = ctx.compiled_parameters.get_original_dict()
-            elif isinstance(ctx.compiled_parameters, list):
-                parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
+            parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
                     
             query = self.convert_statement(query)
             testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))