]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged "fasttypes" branch. this branch changes the signature
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2007 21:53:32 +0000 (21:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2007 21:53:32 +0000 (21:53 +0000)
of convert_bind_param() and convert_result_value() to callable-returning
bind_processor() and result_processor() methods.  if no callable is
returned, no pre/post processing function is called.
- hooks added throughout base/sql/defaults to optimize the calling
of bind param/result processors so that method call overhead is minimized.
special cases added for executemany() scenarios such that unneeded "last row id"
logic doesn't kick in, parameters aren't excessively traversed.
- new performance tests show a combined mass-insert/mass-select test as having 68%
fewer function calls than the same test run against 0.3.
- general performance improvement of result set iteration is around 10-20%.

20 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
test/orm/assorted_eager.py
test/orm/unitofwork.py
test/sql/defaults.py
test/sql/select.py
test/sql/testtypes.py
test/sql/unicode.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index a31ed16377d505d5f627a6472045b87509130a31..a4c67e976b559ac078b7b7f6a75145274d531776 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,5 +1,9 @@
 0.4.0
 - orm
+    - speed ! along with recent speedups to ResultProxy, total number of
+      function calls significantly reduced for large loads.
+      test/perf/masseagerload.py reports 0.4 as having the fewest number
+      of function calls across all SA versions (0.1, 0.2, and 0.3)
 
     - new collection_class api and implementation [ticket:213]
       collections are now instrumented via decorations rather than 
     - improved support for custom column_property() attributes which
       feature correlated subqueries...work better with eager loading now.
 
-    - along with recent speedups to ResultProxy, total number of
-      function calls significantly reduced for large loads.
-      test/perf/masseagerload.py reports 0.4 as having the fewest number
-      of function calls across all SA versions (0.1, 0.2, and 0.3)
-
     - primary key "collapse" behavior; the mapper will analyze all columns
       in its given selectable for primary key "equivalence", that is,
       columns which are equivalent via foreign key relationship or via an
     style of Hibernate
     
 - sql
+  - speed !  clause compilation as well as the mechanics of SQL constructs
+    have been streamlined and simplified to a signficant degree, for a 
+    20-30% improvement of the statement construction/compilation overhead of 
+    0.3
+    
   - all "type" keyword arguments, such as those to bindparam(), column(),
     Column(), and func.<something>(), renamed to "type_".  those objects
     still name their "type" attribute as "type".
     semantics for "__contains__" [ticket:606]
     
 - engines
+  - speed !  the mechanics of result processing and bind parameter processing
+    have been overhauled, streamlined and optimized to issue as little method 
+    calls as possible.  bench tests for mass INSERT and mass rowset iteration
+    both show 0.4 to be over twice as fast as 0.3, using 68% fewer function
+    calls.
+
   - You can now hook into the pool lifecycle and run SQL statements or
     other logic at new each DBAPI connection, pool check-out and check-in.
+
   - Connections gain a .properties collection, with contents scoped to the
     lifetime of the underlying DBAPI connection
   - removed auto_close_cursors and disallow_open_cursors arguments from Pool;
index 4d50b6a25a9674564ce9747aaffc68a0fa49b075..dd4065f3916b6628a551fd29a0739916df8766a8 100644 (file)
@@ -252,24 +252,14 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         for a single statement execution, or one element of an executemany execution.
         """
         
-        if self.parameters is not None:
-            bindparams = self.parameters.copy()
-        else:
-            bindparams = {}
-        bindparams.update(params)
         d = sql.ClauseParameters(self.dialect, self.positiontup)
-        for b in self.binds.values():
-            name = self.bind_names[b]
-            d.set_parameter(b, b.value, name)
 
-        for key, value in bindparams.iteritems():
-            try:
-                b = self.binds[key]
-            except KeyError:
-                continue
-            name = self.bind_names[b]
-            d.set_parameter(b, value, name)
+        pd = self.parameters or {}
+        pd.update(params)
 
+        for key, bind in self.binds.iteritems():
+            d.set_parameter(bind, pd.get(key, bind.value), self.bind_names[bind])
+        
         return d
 
     params = property(lambda self:self.construct_params({}), doc="""Return the `ClauseParameters` corresponding to this compiled object.  
index 3c06822aefc86851f5e53f5fd1a5c891259a1da1..6bf8b96e969c696c2dc1d7ad62ceaa85f20947f7 100644 (file)
@@ -11,16 +11,18 @@ import sqlalchemy.engine.default as default
 
 
 class AcNumeric(types.Numeric):
-    def convert_result_value(self, value, dialect):
-        return value
-
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            # Not sure that this exception is needed
-            return value
-        else:
-            return str(value)
+    def result_processor(self, dialect):
+        return None
 
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                # Not sure that this exception is needed
+                return value
+            else:
+                return str(value)
+        return process
+        
     def get_col_spec(self):
         return "NUMERIC"
 
@@ -28,12 +30,14 @@ class AcFloat(types.Float):
     def get_col_spec(self):
         return "FLOAT"
 
-    def convert_bind_param(self, value, dialect):
+    def bind_processor(self, dialect):
         """By converting to string, we can use Decimal types round-trip."""
-        if not value is None:
-            return str(value)
-        return None
-
+        def process(value):
+            if not value is None:
+                return str(value)
+            return None
+        return process
+        
 class AcInteger(types.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -72,11 +76,11 @@ class AcUnicode(types.Unicode):
     def get_col_spec(self):
         return "TEXT" + (self.length and ("(%d)" % self.length) or "")
 
-    def convert_bind_param(self, value, dialect):
-        return value
+    def bind_processor(self, dialect):
+        return None
 
-    def convert_result_value(self, value, dialect):
-        return value
+    def result_processor(self, dialect):
+        return None
 
 class AcChar(types.CHAR):
     def get_col_spec(self):        
@@ -90,21 +94,25 @@ class AcBoolean(types.Boolean):
     def get_col_spec(self):
         return "YESNO"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-
-    def convert_bind_param(self, value, dialect):
-        if value is True:
-            return 1
-        elif value is False:
-            return 0
-        elif value is None:
-            return None
-        else:
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
             return value and True or False
-
+        return process
+        
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+        
 class AcTimeStamp(types.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
index a3ef999161bf99490c5e81fc68709ba83258fb4b..21ecf15381d3e70f5ab7666e2d94d9e2e72b25b9 100644 (file)
@@ -61,27 +61,33 @@ class InfoDateTime(sqltypes.DateTime ):
     def get_col_spec(self):
         return "DATETIME YEAR TO SECOND"
     
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            if value.microsecond:
-                value = value.replace( microsecond = 0 )
-        return value
-
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                if value.microsecond:
+                    value = value.replace( microsecond = 0 )
+            return value
+        return process
+        
 class InfoTime(sqltypes.Time ):
     def get_col_spec(self):
         return "DATETIME HOUR TO SECOND"
 
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            if value.microsecond:
-                value = value.replace( microsecond = 0 )
-        return value
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                if value.microsecond:
+                    value = value.replace( microsecond = 0 )
+            return value
+        return process
         
-    def convert_result_value(self, value, dialect):
-        if isinstance( value , datetime.datetime ):
-            return value.time()
-        else:
-            return value        
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance( value , datetime.datetime ):
+                return value.time()
+            else:
+                return value        
+        return process
         
 class InfoText(sqltypes.String):
     def get_col_spec(self):
@@ -91,36 +97,45 @@ class InfoString(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
     
-    def convert_bind_param( self , value , dialect ):
-        if value == '':
-            return None
-        else:
-            return value
-
+    def bind_processor(self, dialect):
+        def process(value):
+            if value == '':
+                return None
+            else:
+                return value
+        return process
+        
 class InfoChar(sqltypes.CHAR):
     def get_col_spec(self):
         return "CHAR(%(length)s)" % {'length' : self.length}
+        
 class InfoBinary(sqltypes.Binary):
     def get_col_spec(self):
         return "BYTE"
+        
 class InfoBoolean(sqltypes.Boolean):
     default_type = 'NUM'
     def get_col_spec(self):
         return "SMALLINT"
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-    def convert_bind_param(self, value, dialect):
-        if value is True:
-            return 1
-        elif value is False:
-            return 0
-        elif value is None:
-            return None
-        else:
+        
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
             return value and True or False
-
+        return process
+    
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
         
 colspecs = {
     sqltypes.Integer : InfoInteger,
index 9ec0fbbc3872036794a14bd6fb284759c825d840..308a38a76989bcaef0dc03b322a92ee0b960825e 100644 (file)
@@ -47,16 +47,18 @@ from sqlalchemy.engine import default
 import operator
     
 class MSNumeric(sqltypes.Numeric):
-    def convert_result_value(self, value, dialect):
-        return value
-
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            # Not sure that this exception is needed
-            return value
-        else:
-            return str(value) 
+    def result_processor(self, dialect):
+        return None
 
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                # Not sure that this exception is needed
+                return value
+            else:
+                return str(value) 
+        return process
+        
     def get_col_spec(self):
         if self.precision is None:
             return "NUMERIC"
@@ -67,12 +69,14 @@ class MSFloat(sqltypes.Float):
     def get_col_spec(self):
         return "FLOAT(%(precision)s)" % {'precision': self.precision}
 
-    def convert_bind_param(self, value, dialect):
-        """By converting to string, we can use Decimal types round-trip."""
-        if not value is None:
-            return str(value)
-        return None
-
+    def bind_processor(self, dialect):
+        def process(value):
+            """By converting to string, we can use Decimal types round-trip."""
+            if not value is None:
+                return str(value)
+            return None
+        return process
+        
 class MSInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -108,57 +112,71 @@ class MSTime(sqltypes.Time):
     def get_col_spec(self):
         return "DATETIME"
 
-    def convert_bind_param(self, value, dialect):
-        if isinstance(value, datetime.datetime):
-            value = datetime.datetime.combine(self.__zero_date, value.time())
-        elif isinstance(value, datetime.time):
-            value = datetime.datetime.combine(self.__zero_date, value)
-        return value
-
-    def convert_result_value(self, value, dialect):
-        if isinstance(value, datetime.datetime):
-            return value.time()
-        elif isinstance(value, datetime.date):
-            return datetime.time(0, 0, 0)
-        return value
-
-class MSDateTime_adodbapi(MSDateTime):
-    def convert_result_value(self, value, dialect):
-        # adodbapi will return datetimes with empty time values as datetime.date() objects.
-        # Promote them back to full datetime.datetime()
-        if value and not hasattr(value, 'second'):
-            return datetime.datetime(value.year, value.month, value.day)
-        return value
-
-class MSDateTime_pyodbc(MSDateTime):
-    def convert_bind_param(self, value, dialect):
-        if value and not hasattr(value, 'second'):
-            return datetime.datetime(value.year, value.month, value.day)
-        else:
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                value = datetime.datetime.combine(self.__zero_date, value.time())
+            elif isinstance(value, datetime.time):
+                value = datetime.datetime.combine(self.__zero_date, value)
             return value
-
-class MSDate_pyodbc(MSDate):
-    def convert_bind_param(self, value, dialect):
-        if value and not hasattr(value, 'second'):
-            return datetime.datetime(value.year, value.month, value.day)
-        else:
+        return process
+    
+    def result_processor(self, dialect):
+        def process(value):
+            if isinstance(value, datetime.datetime):
+                return value.time()
+            elif isinstance(value, datetime.date):
+                return datetime.time(0, 0, 0)
             return value
-
-    def convert_result_value(self, value, dialect):
-        # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
-        if value and hasattr(value, 'second'):
-            return value.date()
-        else:
+        return process
+        
+class MSDateTime_adodbapi(MSDateTime):
+    def result_processor(self, dialect):
+        def process(value):
+            # adodbapi will return datetimes with empty time values as datetime.date() objects.
+            # Promote them back to full datetime.datetime()
+            if value and not hasattr(value, 'second'):
+                return datetime.datetime(value.year, value.month, value.day)
             return value
-
+        return process
+        
+class MSDateTime_pyodbc(MSDateTime):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value and not hasattr(value, 'second'):
+                return datetime.datetime(value.year, value.month, value.day)
+            else:
+                return value
+        return process
+        
+class MSDate_pyodbc(MSDate):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value and not hasattr(value, 'second'):
+                return datetime.datetime(value.year, value.month, value.day)
+            else:
+                return value
+        return process
+    
+    def result_processor(self, dialect):
+        def process(value):
+            # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
+            if value and hasattr(value, 'second'):
+                return value.date()
+            else:
+                return value
+        return process
+        
 class MSDate_pymssql(MSDate):
-    def convert_result_value(self, value, dialect):
-        # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
-        if value and hasattr(value, 'second'):
-            return value.date()
-        else:
-            return value
-
+    def result_processor(self, dialect):
+        def process(value):
+            # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
+            if value and hasattr(value, 'second'):
+                return value.date()
+            else:
+                return value
+        return process
+        
 class MSText(sqltypes.TEXT):
     def get_col_spec(self):
         if self.dialect.text_as_varchar:
@@ -181,11 +199,11 @@ class MSNVarchar(sqltypes.Unicode):
 
 class AdoMSNVarchar(MSNVarchar):
     """overrides bindparam/result processing to not convert any unicode strings"""
-    def convert_bind_param(self, value, dialect):
-        return value
+    def bind_processor(self, dialect):
+        return None
 
-    def convert_result_value(self, value, dialect):
-        return value        
+    def result_processor(self, dialect):
+        return None
 
 class MSChar(sqltypes.CHAR):
     def get_col_spec(self):
@@ -203,20 +221,24 @@ class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BIT"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-
-    def convert_bind_param(self, value, dialect):
-        if value is True:
-            return 1
-        elif value is False:
-            return 0
-        elif value is None:
-            return None
-        else:
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
             return value and True or False
+        return process
+    
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
         
 class MSTimeStamp(sqltypes.TIMESTAMP):
     def get_col_spec(self):
index 01d3fa6bcd934fd7a74d55522452a9b30b56b5b1..6d6f32eadb1aa0d1b34b2d842de25e5b93facf9d 100644 (file)
@@ -294,14 +294,20 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
         else:
             return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 
-    def convert_bind_param(self, value, dialect):
-        return value
+    def bind_processor(self, dialect):
+        return None
 
-    def convert_result_value(self, value, dialect):
-        if not self.asdecimal and isinstance(value, util.decimal_type):
-            return float(value)
+    def result_processor(self, dialect):
+        if not self.asdecimal:
+            def process(value):
+                if isinstance(value, util.decimal_type):
+                    return float(value)
+                else:
+                    return value
+            return process
         else:
-            return value
+            return None
+            
 
 
 class MSDecimal(MSNumeric):
@@ -408,8 +414,8 @@ class MSFloat(sqltypes.Float, _NumericType):
         else:
             return self._extend("FLOAT")
 
-    def convert_bind_param(self, value, dialect):
-        return value
+    def bind_processor(self, dialect):
+        return None
 
 
 class MSInteger(sqltypes.Integer, _NumericType):
@@ -539,16 +545,17 @@ class MSBit(sqltypes.TypeEngine):
     def __init__(self, length=None):
         self.length = length
  
-    def convert_result_value(self, value, dialect):
+    def result_processor(self, dialect):
         """Convert a MySQL's 64 bit, variable length binary string to a long."""
-
-        if value is not None:
-            v = 0L
-            for i in map(ord, value):
-                v = v << 8 | i
-            value = v
-        return value
-
+        def process(value):
+            if value is not None:
+                v = 0L
+                for i in map(ord, value):
+                    v = v << 8 | i
+                value = v
+            return value
+        return process
+        
     def get_col_spec(self):
         if self.length is not None:
             return "BIT(%s)" % self.length
@@ -576,13 +583,14 @@ class MSTime(sqltypes.Time):
     def get_col_spec(self):
         return "TIME"
 
-    def convert_result_value(self, value, dialect):
-        # convert from a timedelta value
-        if value is not None:
-            return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60))
-        else:
-            return None
-
+    def result_processor(self, dialect):
+        def process(value):
+            # convert from a timedelta value
+            if value is not None:
+                return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60))
+            else:
+                return None
+        return process
 
 class MSTimeStamp(sqltypes.TIMESTAMP):
     """MySQL TIMESTAMP type.
@@ -930,12 +938,13 @@ class _BinaryType(sqltypes.Binary):
         else:
             return "BLOB"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return buffer(value)
-
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return buffer(value)
+        return process
 
 class MSVarBinary(_BinaryType):
     """MySQL VARBINARY type, for variable length binary data."""
@@ -976,12 +985,13 @@ class MSBinary(_BinaryType):
         else:
             return "BLOB"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return buffer(value)
-
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return buffer(value)
+        return process
 
 class MSBlob(_BinaryType):
     """MySQL BLOB type, for binary data up to 2^16 bytes""" 
@@ -1002,13 +1012,15 @@ class MSBlob(_BinaryType):
             return "BLOB(%d)" % self.length
         else:
             return "BLOB"
-
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return buffer(value)
-
+    
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return buffer(value)
+        return process
+        
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
@@ -1094,12 +1106,18 @@ class MSEnum(MSString):
         length = max([len(v) for v in strip_enums])
         super(MSEnum, self).__init__(length, **kw)
 
-    def convert_bind_param(self, value, engine): 
-        if self.strict and value is not None and value not in self.enums:
-            raise exceptions.InvalidRequestError('"%s" not a valid value for '
-                                                 'this enum' % value)
-        return super(MSEnum, self).convert_bind_param(value, engine)
-
+    def bind_processor(self, dialect):
+        super_convert = super(MSEnum, self).bind_processor(dialect)
+        def process(value):
+            if self.strict and value is not None and value not in self.enums:
+                raise exceptions.InvalidRequestError('"%s" not a valid value for '
+                                                     'this enum' % value)
+            if super_convert:
+                return super_convert(value)
+            else:
+                return value
+        return process
+        
     def get_col_spec(self):
         return self._extend("ENUM(%s)" % ",".join(self.__ddl_values))
 
@@ -1155,36 +1173,44 @@ class MSSet(MSString):
         length = max([len(v) for v in strip_values] + [0])
         super(MSSet, self).__init__(length, **kw)
 
-    def convert_result_value(self, value, dialect):
-        # The good news:
-        #   No ',' quoting issues- commas aren't allowed in SET values
-        # The bad news:
-        #   Plenty of driver inconsistencies here.
-        if isinstance(value, util.set_types):
-            # ..some versions convert '' to an empty set
-            if not value:
-                value.add('')
-            # ..some return sets.Set, even for pythons that have __builtin__.set
-            if not isinstance(value, util.Set):
-                value = util.Set(value)
-            return value
-        # ...and some versions return strings
-        if value is not None:
-            return util.Set(value.split(','))
-        else:
-            return value
-
-    def convert_bind_param(self, value, engine): 
-        if value is None or isinstance(value, (int, long, basestring)):
-            pass
-        else:
-            if None in value:
-                value = util.Set(value)
-                value.remove(None)
-                value.add('')
-            value = ','.join(value)
-        return super(MSSet, self).convert_bind_param(value, engine)
-
+    def result_processor(self, dialect):
+        def process(value):
+            # The good news:
+            #   No ',' quoting issues- commas aren't allowed in SET values
+            # The bad news:
+            #   Plenty of driver inconsistencies here.
+            if isinstance(value, util.set_types):
+                # ..some versions convert '' to an empty set
+                if not value:
+                    value.add('')
+                # ..some return sets.Set, even for pythons that have __builtin__.set
+                if not isinstance(value, util.Set):
+                    value = util.Set(value)
+                return value
+            # ...and some versions return strings
+            if value is not None:
+                return util.Set(value.split(','))
+            else:
+                return value
+        return process
+        
+    def bind_processor(self, dialect):
+        super_convert = super(MSSet, self).bind_processor(dialect)
+        def process(value):
+            if value is None or isinstance(value, (int, long, basestring)):
+                pass
+            else:
+                if None in value:
+                    value = util.Set(value)
+                    value.remove(None)
+                    value.add('')
+                value = ','.join(value)
+            if super_convert:
+                return super_convert(value)
+            else:
+                return value
+        return process
+        
     def get_col_spec(self):
         return self._extend("SET(%s)" % ",".join(self.__ddl_values))
 
@@ -1195,21 +1221,24 @@ class MSBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOL"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-
-    def convert_bind_param(self, value, dialect):
-        if value is True:
-            return 1
-        elif value is False:
-            return 0
-        elif value is None:
-            return None
-        else:
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
             return value and True or False
-
+        return process
+        
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
 
 colspecs = {
     sqltypes.Integer: MSInteger,
@@ -1284,7 +1313,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                                re.I | re.UNICODE)
 
     def post_exec(self):
-        if self.compiled.isinsert:
+        if self.compiled.isinsert and not self.executemany:
             if (not len(self._last_inserted_ids) or
                 self._last_inserted_ids[0] is None):
                 self._last_inserted_ids = ([self.cursor.lastrowid] +
index 2c45c94e8e5c55c03cb139a72774f5a8efbe1292..520332d45e4b82ecabe5b366b68d9176541a5db9 100644 (file)
@@ -32,26 +32,31 @@ class OracleSmallInteger(sqltypes.Smallinteger):
 class OracleDate(sqltypes.Date):
     def get_col_spec(self):
         return "DATE"
-    def convert_bind_param(self, value, dialect):
-        return value
-    def convert_result_value(self, value, dialect):
-        if not isinstance(value, datetime.datetime):
-            return value
-        else:
-            return value.date()
+    def bind_processor(self, dialect):
+        return None
 
+    def result_processor(self, dialect):
+        def process(value):
+            if not isinstance(value, datetime.datetime):
+                return value
+            else:
+                return value.date()
+        return process
+        
 class OracleDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATE"
         
-    def convert_result_value(self, value, dialect):
-        if value is None or isinstance(value,datetime.datetime):
-            return value
-        else:
-            # convert cx_oracle datetime object returned pre-python 2.4
-            return datetime.datetime(value.year,value.month,
-                value.day,value.hour, value.minute, value.second)
-
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None or isinstance(value,datetime.datetime):
+                return value
+            else:
+                # convert cx_oracle datetime object returned pre-python 2.4
+                return datetime.datetime(value.year,value.month,
+                    value.day,value.hour, value.minute, value.second)
+        return process
+        
 # Note:
 # Oracle DATE == DATETIME
 # Oracle does not allow milliseconds in DATE
@@ -65,14 +70,15 @@ class OracleTimestamp(sqltypes.TIMESTAMP):
     def get_dbapi_type(self, dialect):
         return dialect.TIMESTAMP
 
-    def convert_result_value(self, value, dialect):
-        if value is None or isinstance(value,datetime.datetime):
-            return value
-        else:
-            # convert cx_oracle datetime object returned pre-python 2.4
-            return datetime.datetime(value.year,value.month,
-                value.day,value.hour, value.minute, value.second)
-
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None or isinstance(value,datetime.datetime):
+                return value
+            else:
+                # convert cx_oracle datetime object returned pre-python 2.4
+                return datetime.datetime(value.year,value.month,
+                    value.day,value.hour, value.minute, value.second)
+        return process
 
 class OracleString(sqltypes.String):
     def get_col_spec(self):
@@ -85,15 +91,23 @@ class OracleText(sqltypes.TEXT):
     def get_col_spec(self):
         return "CLOB"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        elif hasattr(value, 'read'):
-            # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str
-            return super(OracleText, self).convert_result_value(value.read(), dialect)
-        else:
-            return super(OracleText, self).convert_result_value(value, dialect)
-
+    def result_processor(self, dialect):
+        super_process = super(OracleText, self).result_processor(dialect)
+        def process(value):
+            if value is None:
+                return None
+            elif hasattr(value, 'read'):
+                # cx_oracle doesnt seem to be consistent with CLOB returning LOB or str
+                if super_process:
+                    return super_process(value.read())
+                else:
+                    return value.read()
+            else:
+                if super_process:
+                    return super_process(value)
+                else:
+                    return value
+        return process
 
 class OracleRaw(sqltypes.Binary):
     def get_col_spec(self):
@@ -110,34 +124,40 @@ class OracleBinary(sqltypes.Binary):
     def get_col_spec(self):
         return "BLOB"
 
-    def convert_bind_param(self, value, dialect):
-        return value
-
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return value.read()
+    def bind_processor(self, dialect):
+        return None
 
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return value.read()
+        return process
+        
 class OracleBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "SMALLINT"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-
-    def convert_bind_param(self, value, dialect):
-        if value is True:
-            return 1
-        elif value is False:
-            return 0
-        elif value is None:
-            return None
-        else:
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
             return value and True or False
-
+        return process
+        
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is True:
+                return 1
+            elif value is False:
+                return 0
+            elif value is None:
+                return None
+            else:
+                return value and True or False
+        return process
+        
 colspecs = {
     sqltypes.Integer : OracleInteger,
     sqltypes.Smallinteger : OracleSmallInteger,
@@ -196,7 +216,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
             if self.compiled_parameters is not None:
                  for k in self.out_parameters:
                      type = self.compiled_parameters.get_type(k)
-                     self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+                     self.out_parameters[k] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[k].getvalue())
             else:
                  for k in self.out_parameters:
                      self.out_parameters[k] = self.out_parameters[k].getvalue()
index a30832b43d1ed2686b456955487742775b89375c..e4897bba6b2f631c664ea87df6cea1e1ec9e56e8 100644 (file)
@@ -22,14 +22,19 @@ class PGNumeric(sqltypes.Numeric):
         else:
             return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
 
-    def convert_bind_param(self, value, dialect):
-        return value
+    def bind_processor(self, dialect):
+        return None
 
-    def convert_result_value(self, value, dialect):
-        if not self.asdecimal and isinstance(value, util.decimal_type):
-            return float(value)
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            return None
         else:
-            return value
+            def process(value):
+                if isinstance(value, util.decimal_type):
+                    return float(value)
+                else:
+                    return value
+            return process
         
 class PGFloat(sqltypes.Float):
     def get_col_spec(self):
@@ -98,25 +103,38 @@ class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
         impl.__dict__.update(self.__dict__)
         impl.item_type = self.item_type.dialect_impl(dialect)
         return impl
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            return value
-        def convert_item(item):
-            if isinstance(item, (list,tuple)):
-                return [convert_item(child) for child in item]
-            else:
-                return self.item_type.convert_bind_param(item, dialect)
-        return [convert_item(item) for item in value]
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return value
-        def convert_item(item):
-            if isinstance(item, list):
-                return [convert_item(child) for child in item]
-            else:
-                return self.item_type.convert_result_value(item, dialect)
-        # Could specialcase when item_type.convert_result_value is the default identity func
-        return [convert_item(item) for item in value]
+        
+    def bind_processor(self, dialect):
+        item_proc = self.item_type.bind_processor(dialect)
+        def process(value):
+            if value is None:
+                return value
+            def convert_item(item):
+                if isinstance(item, (list,tuple)):
+                    return [convert_item(child) for child in item]
+                else:
+                    if item_proc:
+                        return item_proc(item)
+                    else:
+                        return item
+            return [convert_item(item) for item in value]
+        return process
+        
+    def result_processor(self, dialect):
+        item_proc = self.item_type.bind_processor(dialect)
+        def process(value):
+            if value is None:
+                return value
+            def convert_item(item):
+                if isinstance(item, list):
+                    return [convert_item(child) for child in item]
+                else:
+                    if item_proc:
+                        return item_proc(item)
+                    else:
+                        return item
+            return [convert_item(item) for item in value]
+        return process
     def get_col_spec(self):
         return self.item_type.get_col_spec() + '[]'
 
index 7999cc40330ba5ef78849047f974973b97a03683..3cc821a3602afcf1b7ddca20eb4660e99cf7626a 100644 (file)
@@ -32,15 +32,17 @@ class SLSmallInteger(sqltypes.Smallinteger):
         return "SMALLINT"
 
 class DateTimeMixin(object):
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            if getattr(value, 'microsecond', None) is not None:
-                return value.strftime(self.__format__ + "." + str(value.microsecond))
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                if getattr(value, 'microsecond', None) is not None:
+                    return value.strftime(self.__format__ + "." + str(value.microsecond))
+                else:
+                    return value.strftime(self.__format__)
             else:
-                return value.strftime(self.__format__)
-        else:
-            return None
-
+                return None
+        return process
+        
     def _cvt(self, value, dialect):
         if value is None:
             return None
@@ -57,30 +59,36 @@ class SLDateTime(DateTimeMixin,sqltypes.DateTime):
     def get_col_spec(self):
         return "TIMESTAMP"
 
-    def convert_result_value(self, value, dialect):
-        tup = self._cvt(value, dialect)
-        return tup and datetime.datetime(*tup)
-
+    def result_processor(self, dialect):
+        def process(value):
+            tup = self._cvt(value, dialect)
+            return tup and datetime.datetime(*tup)
+        return process
+        
 class SLDate(DateTimeMixin, sqltypes.Date):
     __format__ = "%Y-%m-%d"
 
     def get_col_spec(self):
         return "DATE"
 
-    def convert_result_value(self, value, dialect):
-        tup = self._cvt(value, dialect)
-        return tup and datetime.date(*tup[0:3])
-
+    def result_processor(self, dialect):
+        def process(value):
+            tup = self._cvt(value, dialect)
+            return tup and datetime.date(*tup[0:3])
+        return process
+        
 class SLTime(DateTimeMixin, sqltypes.Time):
     __format__ = "%H:%M:%S"
 
     def get_col_spec(self):
         return "TIME"
 
-    def convert_result_value(self, value, dialect):
-        tup = self._cvt(value, dialect)
-        return tup and datetime.time(*tup[3:7])
-
+    def result_processor(self, dialect):
+        def process(value):
+            tup = self._cvt(value, dialect)
+            return tup and datetime.time(*tup[3:7])
+        return process
+        
 class SLText(sqltypes.TEXT):
     def get_col_spec(self):
         return "TEXT"
@@ -101,16 +109,20 @@ class SLBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "BOOLEAN"
 
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            return None
-        return value and 1 or 0
-
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        return value and True or False
-
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and 1 or 0
+        return process
+    
+    def result_processor(self, dialect):
+        def process(value):
+            if value is None:
+                return None
+            return value and True or False
+        return process
+        
 colspecs = {
     sqltypes.Integer : SLInteger,
     sqltypes.Smallinteger : SLSmallInteger,
@@ -150,7 +162,7 @@ def descriptor():
 
 class SQLiteExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
-        if self.compiled.isinsert:
+        if self.compiled.isinsert and not self.executemany:
             if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
                 self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
 
index 284be6dfef5edd41f06edf5248c8a39f208f8f11..8fe34bf3f5605b98f7f257fd8ad905f815eb1747 100644 (file)
@@ -1127,20 +1127,17 @@ class ResultProxy(object):
       col3 = row[mytable.c.mycol] # access via Column object.
 
     ResultProxy also contains a map of TypeEngine objects and will
-    invoke the appropriate ``convert_result_value()`` method before
+    invoke the appropriate ``result_processor()`` method before
     returning columns, as well as the ExecutionContext corresponding
     to the statement execution.  It provides several methods for which
     to obtain information from the underlying ExecutionContext.
     """
 
-    class AmbiguousColumn(object):
-        def __init__(self, key):
-            self.key = key
-        def dialect_impl(self, dialect):
-            return self
-        def convert_result_value(self, arg, engine):
-            raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
-
+    def __ambiguous_processor(self, colname):
+        def process(value):
+            raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname)
+        return process
+            
     def __init__(self, context):
         """ResultProxy objects are constructed via the execute() method on SQLEngine."""
         self.context = context
@@ -1185,13 +1182,13 @@ class ResultProxy(object):
                 else:
                     type = typemap.get(item[1], types.NULLTYPE)
 
-                rec = (type, type.dialect_impl(self.dialect), i)
+                rec = (type, type.dialect_impl(self.dialect).result_processor(self.dialect), i)
 
                 if rec[0] is None:
                     raise exceptions.InvalidRequestError(
                         "None for metadata " + colname)
                 if self.__props.setdefault(colname.lower(), rec) is not rec:
-                    self.__props[colname.lower()] = (type, ResultProxy.AmbiguousColumn(colname), 0)
+                    self.__props[colname.lower()] = (type, self.__ambiguous_processor(colname), 0)
                 self.__keys.append(colname)
                 self.__props[i] = rec
 
@@ -1298,7 +1295,10 @@ class ResultProxy(object):
 
     def _get_col(self, row, key):
         rec = self._key_cache[key]
-        return rec[1].convert_result_value(row[rec[2]], self.dialect)
+        if rec[1]:
+            return rec[1](row[rec[2]])
+        else:
+            return row[rec[2]]
     
     def _fetchone_impl(self):
         return self.cursor.fetchone()
index 02802452a228db1cb2a7c2f02ba3f8a426d6f4a1..ccaf080e75826892e4f534132f3af828a1ee3a4e 100644 (file)
@@ -163,12 +163,17 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.statement = unicode(compiled)
             if parameters is None:
                 self.compiled_parameters = compiled.construct_params({})
+                self.executemany = False
             elif not isinstance(parameters, (list, tuple)):
                 self.compiled_parameters = compiled.construct_params(parameters)
+                self.executemany = False
             else:
                 self.compiled_parameters = [compiled.construct_params(m or {}) for m in parameters]
                 if len(self.compiled_parameters) == 1:
                     self.compiled_parameters = self.compiled_parameters[0]
+                    self.executemany = False
+                else:
+                    self.executemany = True
         elif statement is not None:
             self.typemap = self.column_labels = None
             self.parameters = self.__encode_param_keys(parameters)
@@ -206,22 +211,26 @@ class DefaultExecutionContext(base.ExecutionContext):
                 return proc(params)
 
     def __convert_compiled_params(self, parameters):
-        executemany = parameters is not None and isinstance(parameters, list)
         encode = not self.dialect.supports_unicode_statements()
         # the bind params are a CompiledParams object.  but all the DBAPI's hate
         # that object (or similar).  so convert it to a clean
         # dictionary/list/tuple of dictionary/tuple of list
         if parameters is not None:
-           if self.dialect.positional:
-                if executemany:
-                    parameters = [p.get_raw_list() for p in parameters]
+            if self.executemany:
+                processors = parameters[0].get_processors()
+            else:
+                processors = parameters.get_processors()
+
+            if self.dialect.positional:
+                if self.executemany:
+                    parameters = [p.get_raw_list(processors) for p in parameters]
                 else:
-                    parameters = parameters.get_raw_list()
-           else:
-                if executemany:
-                    parameters = [p.get_raw_dict(encode_keys=encode) for p in parameters]
+                    parameters = parameters.get_raw_list(processors)
+            else:
+                if self.executemany:
+                    parameters = [p.get_raw_dict(processors, encode_keys=encode) for p in parameters]
                 else:
-                    parameters = parameters.get_raw_dict(encode_keys=encode)
+                    parameters = parameters.get_raw_dict(processors, encode_keys=encode)
         return parameters
                 
     def is_select(self):
@@ -311,28 +320,31 @@ class DefaultExecutionContext(base.ExecutionContext):
         """generate default values for compiled insert/update statements,
         and generate last_inserted_ids() collection."""
 
-        # TODO: cleanup
         if self.isinsert:
-            if isinstance(self.compiled_parameters, list):
-                plist = self.compiled_parameters
-            else:
-                plist = [self.compiled_parameters]
             drunner = self.dialect.defaultrunner(self)
-            for param in plist:
+            if self.executemany:
+                # executemany doesn't populate last_inserted_ids()
+                firstparam = self.compiled_parameters[0]
+                processors = firstparam.get_processors()
+                for c in self.compiled.statement.table.c:
+                    if c.default is not None:
+                        params = self.compiled_parameters
+                        for param in params:
+                            if not c.key in param or param.get_original(c.key) is None:
+                                self.compiled_parameters = param
+                                newid = drunner.get_column_default(c)
+                                if newid is not None:
+                                    param.set_value(c.key, newid)
+                        self.compiled_parameters = params
+            else:
+                param = self.compiled_parameters
+                processors = param.get_processors()
                 last_inserted_ids = []
-                # check the "default" status of each column in the table
                 for c in self.compiled.statement.table.c:
-                    # check if it will be populated by a SQL clause - we'll need that
-                    # after execution.
                     if c in self.compiled.inline_params:
                         self._postfetch_cols.add(c)
                         if c.primary_key:
                             last_inserted_ids.append(None)
-                    # check if its not present at all.  see if theres a default
-                    # and fire it off, and add to bind parameters.  if
-                    # its a pk, add the value to our last_inserted_ids list,
-                    # or, if its a SQL-side default, let it fire off on the DB side, but we'll need
-                    # the SQL-generated value after execution.
                     elif not c.key in param or param.get_original(c.key) is None:
                         if isinstance(c.default, schema.PassiveDefault):
                             self._postfetch_cols.add(c)
@@ -340,32 +352,33 @@ class DefaultExecutionContext(base.ExecutionContext):
                         if newid is not None:
                             param.set_value(c.key, newid)
                             if c.primary_key:
-                                last_inserted_ids.append(param.get_processed(c.key))
+                                last_inserted_ids.append(param.get_processed(c.key, processors))
                         elif c.primary_key:
                             last_inserted_ids.append(None)
-                    # its an explicitly passed pk value - add it to
-                    # our last_inserted_ids list.
                     elif c.primary_key:
-                        last_inserted_ids.append(param.get_processed(c.key))
-                # TODO: we arent accounting for executemany() situations
-                # here (hard to do since lastrowid doesnt support it either)
+                        last_inserted_ids.append(param.get_processed(c.key, processors))
                 self._last_inserted_ids = last_inserted_ids
                 self._last_inserted_params = param
+
+
         elif self.isupdate:
-            if isinstance(self.compiled_parameters, list):
-                plist = self.compiled_parameters
-            else:
-                plist = [self.compiled_parameters]
             drunner = self.dialect.defaultrunner(self)
-            for param in plist:
-                # check the "onupdate" status of each column in the table
+            if self.executemany:
+                for c in self.compiled.statement.table.c:
+                    if c.onupdate is not None:
+                        params = self.compiled_parameters
+                        for param in params:
+                            if not c.key in param or param.get_original(c.key) is None:
+                                self.compiled_parameters = param
+                                value = drunner.get_column_onupdate(c)
+                                if value is not None:
+                                    param.set_value(c.key, value)
+                        self.compiled_parameters = params
+            else:
+                param = self.compiled_parameters
                 for c in self.compiled.statement.table.c:
-                    # it will be populated by a SQL clause - we'll need that
-                    # after execution.
                     if c in self.compiled.inline_params:
                         self._postfetch_cols.add(c)
-                    # its not in the bind parameters, and theres an "onupdate" defined for the column;
-                    # execute it and add to bind params
                     elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None):
                         value = drunner.get_column_onupdate(c)
                         if value is not None:
index 3fc13a50dc7e0b416d7934f02bb4c68a2fd2d693..994a877bd57f367a9d26c61ae54f71d3d345bb65 100644 (file)
@@ -812,7 +812,6 @@ class ClauseParameters(object):
     """
 
     def __init__(self, dialect, positional=None):
-        super(ClauseParameters, self).__init__()
         self.dialect = dialect
         self.__binds = {}
         self.positional = positional or []
@@ -829,19 +828,31 @@ class ClauseParameters(object):
     def get_type(self, key):
         return self.__binds[key][0].type
 
-    def get_processed(self, key):
-        (bind, name, value) = self.__binds[key]
-        return bind.typeprocess(value, self.dialect)
-   
+    def get_processors(self):
+        """return a dictionary of bind 'processing' functions"""
+        return dict([
+            (key, value) for key, value in 
+            [(
+                key,
+                self.__binds[key][0].bind_processor(self.dialect)
+            ) for key in self.__binds]
+            if value is not None
+        ])
+    
+    def get_processed(self, key, processors):
+        return key in processors and processors[key](self.__binds[key][2]) or self.__binds[key][2]
+            
     def keys(self):
         return self.__binds.keys()
 
     def __iter__(self):
         return iter(self.keys())
-    def __getitem__(self, key):
-        return self.get_processed(key)
         
+    def __getitem__(self, key):
+        (bind, name, value) = self.__binds[key]
+        processor = bind.bind_processor(self.dialect)
+        return processor is not None and processor(value) or value
     def __contains__(self, key):
         return key in self.__binds
     
@@ -851,14 +862,36 @@ class ClauseParameters(object):
     def get_original_dict(self):
         return dict([(name, value) for (b, name, value) in self.__binds.values()])
 
-    def get_raw_list(self):
-        return [self.get_processed(key) for key in self.positional]
+    def get_raw_list(self, processors):
+#        (bind, name, value) = self.__binds[key]
+        return [
+            (key in processors) and
+                processors[key](self.__binds[key][2]) or
+                self.__binds[key][2]
+            for key in self.positional
+        ]
 
-    def get_raw_dict(self, encode_keys=False):
+    def get_raw_dict(self, processors, encode_keys=False):
         if encode_keys:
-            return dict([(key.encode(self.dialect.encoding), self.get_processed(key)) for key in self.keys()])
+            return dict([
+                (
+                    key.encode(self.dialect.encoding),
+                    (key in processors) and
+                        processors[key](self.__binds[key][2]) or
+                        self.__binds[key][2]
+                )
+                for key in self.keys()
+            ])
         else:
-            return dict([(key, self.get_processed(key)) for key in self.keys()])
+            return dict([
+                (
+                    key,
+                    (key in processors) and
+                        processors[key](self.__binds[key][2]) or
+                        self.__binds[key][2]
+                )
+                for key in self.keys()
+            ])
 
     def __repr__(self):
         return self.__class__.__name__ + ":" + repr(self.get_original_dict())
@@ -1995,8 +2028,8 @@ class _BindParamClause(ClauseElement, _CompareMixin):
     def _get_from_objects(self, **modifiers):
         return []
 
-    def typeprocess(self, value, dialect):
-        return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
+    def bind_processor(self, dialect):
+        return self.type.dialect_impl(dialect).bind_processor(dialect)
 
     def _compare_type(self, obj):
         if not isinstance(self.type, sqltypes.NullType):
index fe05910df40bd2cda52b597439b509683fa01c01..f3854e3e15863b00f959a642c29fedfe1cbe9a84 100644 (file)
@@ -59,12 +59,13 @@ class TypeEngine(AbstractType):
     def get_col_spec(self):
         raise NotImplementedError()
 
-    def convert_bind_param(self, value, dialect):
-        return value
-
-    def convert_result_value(self, value, dialect):
-        return value
 
+    def bind_processor(self, dialect):
+        return None
+        
+    def result_processor(self, dialect):
+        return None
+        
     def adapt(self, cls):
         return cls()
     
@@ -115,11 +116,11 @@ class TypeDecorator(AbstractType):
     def get_col_spec(self):
         return self.impl.get_col_spec()
 
-    def convert_bind_param(self, value, dialect):
-        return self.impl.convert_bind_param(value, dialect)
+    def bind_processor(self, dialect):
+        return self.impl.bind_processor(dialect)
 
-    def convert_result_value(self, value, dialect):
-        return self.impl.convert_result_value(value, dialect)
+    def result_processor(self, dialect):
+        return self.impl.result_processor(dialect)
 
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
@@ -183,11 +184,6 @@ class NullType(TypeEngine):
     def get_col_spec(self):
         raise NotImplementedError()
 
-    def convert_bind_param(self, value, dialect):
-        return value
-
-    def convert_result_value(self, value, dialect):
-        return value
 NullTypeEngine = NullType
 
 class Concatenable(object):
@@ -202,11 +198,27 @@ class String(TypeEngine, Concatenable):
     def adapt(self, impltype):
         return impltype(length=self.length, convert_unicode=self.convert_unicode)
 
-    def convert_bind_param(self, value, dialect):
-        if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode):
-            return value
+    def bind_processor(self, dialect):
+        if self.convert_unicode or dialect.convert_unicode:
+            def process(value):
+                if isinstance(value, unicode):
+                    return value.encode(dialect.encoding)
+                else:
+                    return value
+            return process
         else:
-            return value.encode(dialect.encoding)
+            return None
+        
+    def result_processor(self, dialect):
+        if self.convert_unicode or dialect.convert_unicode:
+            def process(value):
+                if value is not None and not isinstance(value, unicode):
+                    return value.decode(dialect.encoding)
+                else:
+                    return value
+            return process
+        else:
+            return None
 
     def get_search_list(self):
         l = super(String, self).get_search_list()
@@ -215,11 +227,6 @@ class String(TypeEngine, Concatenable):
         else:
             return l
 
-    def convert_result_value(self, value, dialect):
-        if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode):
-            return value
-        else:
-            return value.decode(dialect.encoding)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.STRING
@@ -254,17 +261,24 @@ class Numeric(TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            return float(value)
-        else:
-            return value
-            
-    def convert_result_value(self, value, dialect):
-        if value is not None and self.asdecimal:
-            return Decimal(str(value))
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                return float(value)
+            else:
+                return value
+        return process
+    
+    def result_processor(self, dialect):
+        if self.asdecimal:
+            def process(value):
+                if value is not None:
+                    return Decimal(str(value))
+                else:
+                    return value
+            return process
         else:
-            return value
+            return None
 
 class Float(Numeric):
     def __init__(self, precision = 10, asdecimal=False, **kwargs):
@@ -308,15 +322,14 @@ class Binary(TypeEngine):
     def __init__(self, length=None):
         self.length = length
 
-    def convert_bind_param(self, value, dialect):
-        if value is not None:
-            return dialect.dbapi.Binary(value)
-        else:
-            return None
-
-    def convert_result_value(self, value, dialect):
-        return value
-
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                return dialect.dbapi.Binary(value)
+            else:
+                return None
+        return process
+        
     def adapt(self, impltype):
         return impltype(length=self.length)
 
@@ -332,17 +345,27 @@ class PickleType(MutableType, TypeDecorator):
         self.mutable = mutable
         super(PickleType, self).__init__()
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        buf = self.impl.convert_result_value(value, dialect)
-        return self.pickler.loads(str(buf))
-
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            return None
-        return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect)
-
+    def bind_processor(self, dialect):
+        impl_process = self.impl.bind_processor(dialect)
+        def process(value):
+            if value is None:
+                return None
+            if impl_process is None:
+                return self.pickler.dumps(value, self.protocol)
+            else:
+                return impl_process(self.pickler.dumps(value, self.protocol))
+        return process
+    
+    def result_processor(self, dialect):
+        impl_process = self.impl.result_processor(dialect)
+        def process(value):
+            if value is None:
+                return None
+            if impl_process is not None:
+                value = impl_process(value)
+            return self.pickler.loads(str(value))
+        return process
+        
     def copy_value(self, value):
         if self.mutable:
             return self.pickler.loads(self.pickler.dumps(value, self.protocol))
@@ -370,8 +393,8 @@ class Interval(TypeDecorator):
 
         Converting is very simple - just use epoch(zero timestamp, 01.01.1970) as
         base, so if we need to store timedelta = 1 day (24 hours) in database it
-        will be stored as DateTime = '2nd Jan 1970 00:00', see convert_bind_param
-        and convert_result_value to actual conversion code
+        will be stored as DateTime = '2nd Jan 1970 00:00', see bind_processor
+        and result_processor to actual conversion code
     """
     #Empty useless type, because at the moment of creation of instance we don't
     #know what type will be decorated - it depends on used dialect.
@@ -396,25 +419,35 @@ class Interval(TypeDecorator):
         
     def __hasNativeImpl(self,dialect):
         return dialect.__class__ in self.__supported
-            
-    def convert_bind_param(self, value, dialect):
-        if value is None:
-            return None
-        if not self.__hasNativeImpl(dialect):
-            tmpval = dt.datetime.utcfromtimestamp(0) + value
-            return self.impl.convert_bind_param(tmpval,dialect)
+    
+    def bind_processor(self, dialect):
+        impl_processor = self.impl.bind_processor(dialect)
+        if self.__hasNativeImpl(dialect):
+            return impl_processor
         else:
-            return self.impl.convert_bind_param(value,dialect)
-
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        retval = self.impl.convert_result_value(value,dialect)
-        if not self.__hasNativeImpl(dialect):
-            return retval - dt.datetime.utcfromtimestamp(0)
+            def process(value):
+                if value is None:
+                    return None
+                tmpval = dt.datetime.utcfromtimestamp(0) + value
+                if impl_processor is not None:
+                    return impl_processor(tmpval)
+                else:
+                    return tmpval
+            return process
+            
+    def result_processor(self, dialect):
+        impl_processor = self.impl.result_processor(dialect)
+        if self.__hasNativeImpl(dialect):
+            return impl_processor
         else:
-            return retval
-
+            def process(value):
+                if value is None:
+                    return None
+                if impl_processor is not None:
+                    value = impl_processor(value)
+                return value - dt.datetime.utcfromtimestamp(0)
+            return process
+            
 class FLOAT(Float):pass
 class TEXT(String):pass
 class DECIMAL(Numeric):pass
index 652186b8e6171bae5a98ac9cad92d1b8c48f5c60..ce17e8dfd333da9949bf6e4a8e1c78cf2bfe471a 100644 (file)
@@ -13,7 +13,7 @@ class EagerTest(AssertMixin):
         dbmeta = MetaData(testbase.db)
         
         # determine a literal value for "false" based on the dialect
-        false = Boolean().dialect_impl(testbase.db.dialect).convert_bind_param(False, testbase.db.dialect)
+        false = Boolean().dialect_impl(testbase.db.dialect).bind_processor(testbase.db.dialect)(False)
         
         owners = Table ( 'owners', dbmeta ,
                Column ( 'id', Integer, primary_key=True, nullable=False ),
index 0ef64746f4e149870df4c4aee749ed04e3b50837..c7a5c055a03308c811bc51f5b9a01aa4c4607567 100644 (file)
@@ -460,7 +460,7 @@ class ClauseAttributesTest(UnitOfWorkTest):
         global metadata, users_table
         metadata = MetaData(testbase.db)
         users_table = Table('users', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('users_id_seq', optional=True), primary_key=True),
             Column('name', String(30)),
             Column('counter', Integer, default=1))
         metadata.create_all()
@@ -995,7 +995,10 @@ class SaveTest(UnitOfWorkTest):
         u = User()
         u.user_id=42
         Session.commit()
-    
+  
+    # why no support on oracle ?  because oracle doesn't save
+    # "blank" strings; it saves a single space character. 
+    @testing.unsupported('oracle') 
     def test_dont_update_blanks(self):
         mapper(User, users)
         u = User()
index 0df49ea39b39f866692125faf9bb825f3259a80e..76bd2c41fb89bf4f63c50c3c40d4dce84be66d3f 100644 (file)
@@ -9,14 +9,14 @@ import datetime
 class DefaultTest(PersistTest):
 
     def setUpAll(self):
-        global t, f, f2, ts, currenttime, metadata
+        global t, f, f2, ts, currenttime, metadata, default_generator
 
         db = testbase.db
         metadata = MetaData(db)
-        x = {'x':50}
+        default_generator = {'x':50}
         def mydefault():
-            x['x'] += 1
-            return x['x']
+            default_generator['x'] += 1
+            return default_generator['x']
 
         def myupdate_with_ctx(ctx):
             return len(ctx.compiled_parameters['col2'])
@@ -96,6 +96,7 @@ class DefaultTest(PersistTest):
         t.drop()
     
     def tearDown(self):
+        default_generator['x'] = 50
         t.delete().execute()
     
     def testargsignature(self):
@@ -125,7 +126,14 @@ class DefaultTest(PersistTest):
         t.insert().execute()
 
         ctexec = currenttime.scalar()
-        print "Currenttime "+ repr(ctexec)
+        l = t.select().execute()
+        today = datetime.date.today()
+        self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)])
+
+    def testinsertmany(self):
+        r = t.insert().execute({}, {}, {})
+
+        ctexec = currenttime.scalar()
         l = t.select().execute()
         today = datetime.date.today()
         self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)])
@@ -135,6 +143,25 @@ class DefaultTest(PersistTest):
         l = t.select().execute()
         self.assert_(l.fetchone()['col3'] == 50)
         
+    def testupdatemany(self):
+        t.insert().execute({}, {}, {})
+
+        t.update(t.c.col1==bindparam('pkval')).execute(
+            {'pkval':51,'col7':None, 'col8':None, 'boolcol1':False},
+        )
+        
+        
+        t.update(t.c.col1==bindparam('pkval')).execute(
+            {'pkval':51,},
+            {'pkval':52,},
+            {'pkval':53,},
+        )
+
+        l = t.select().execute()
+        ctexec = currenttime.scalar()
+        today = datetime.date.today()
+        self.assert_(l.fetchall() == [(51, 'im the update', f2, ts, ts, ctexec, False, False, 13, today), (52, 'im the update', f2, ts, ts, ctexec, True, False, 13, today), (53, 'im the update', f2, ts, ts, ctexec, True, False, 13, today)])
+        
         
     def testupdate(self):
         r = t.insert().execute()
@@ -147,7 +174,7 @@ class DefaultTest(PersistTest):
         self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today()))
         # mysql/other db's return 0 or 1 for count(1)
         self.assert_(14 <= f2 <= 15)
-
+            
     def testupdatevalues(self):
         r = t.insert().execute()
         pk = r.last_inserted_ids()[0]
index f5932d515abea981a420eb36084846410d710ffb..865f1ec48b5f3a0a48e1bc692fbc7bfec1b0fea5 100644 (file)
@@ -902,9 +902,9 @@ EXISTS (select yay from foo where boo = lar)",
                 self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect())
                 nonpositional = stmt.compile()
                 positional = stmt.compile(dialect=sqlite.dialect())
-                assert positional.get_params().get_raw_list() == expected_default_params_list
-                assert nonpositional.get_params(**test_param_dict).get_raw_dict() == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict()))
-                assert positional.get_params(**test_param_dict).get_raw_list() == expected_test_params_list
+                assert positional.get_params().get_raw_list({}) == expected_default_params_list
+                assert nonpositional.get_params(**test_param_dict).get_raw_dict({}) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict).get_raw_dict()))
+                assert positional.get_params(**test_param_dict).get_raw_list({}) == expected_test_params_list
         
         # check that params() doesnt modify original statement
         s = select([table1], or_(table1.c.myid==bindparam('myid'), table2.c.otherid==bindparam('myotherid')))
index 6590330164eaa0c2f7c101106439bcf93757533c..d917b4a18cf640095b51bd74183eed340201ca7c 100644 (file)
@@ -10,28 +10,47 @@ from testlib import *
 class MyType(types.TypeEngine):
     def get_col_spec(self):
         return "VARCHAR(100)"
-    def convert_bind_param(self, value, engine):
-        return "BIND_IN"+ value
-    def convert_result_value(self, value, engine):
-        return value + "BIND_OUT"
+    def bind_processor(self, dialect):
+        def process(value):
+            return "BIND_IN"+ value
+        return process
+    def result_processor(self, dialect):
+        def process(value):
+            return value + "BIND_OUT"
+        return process
     def adapt(self, typeobj):
         return typeobj()
 
 class MyDecoratedType(types.TypeDecorator):
     impl = String
-    def convert_bind_param(self, value, dialect):
-        return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect)
-    def convert_result_value(self, value, dialect):
-        return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT"
+    def bind_processor(self, dialect):
+        impl_processor = super(MyDecoratedType, self).bind_processor(dialect) or (lambda value:value)
+        def process(value):
+            return "BIND_IN"+ impl_processor(value)
+        return process
+    def result_processor(self, dialect):
+        impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
+        def process(value):
+            return impl_processor(value) + "BIND_OUT"
+        return process
     def copy(self):
         return MyDecoratedType()
         
 class MyUnicodeType(types.TypeDecorator):
     impl = Unicode
-    def convert_bind_param(self, value, dialect):
-        return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect)
-    def convert_result_value(self, value, dialect):
-        return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT"
+    
+    def bind_processor(self, dialect):
+        impl_processor = super(MyUnicodeType, self).bind_processor(dialect)
+        def process(value):
+            return "UNI_BIND_IN"+ impl_processor(value)
+        return process
+        
+    def result_processor(self, dialect):
+        impl_processor = super(MyUnicodeType, self).result_processor(dialect)
+        def process(value):
+            return impl_processor(value) + "UNI_BIND_OUT"
+        return process
+
     def copy(self):
         return MyUnicodeType(self.impl.length)
 
index b66d001be0e478d7e82757b5a4edee2b6f4ba8ac..19e78ed59f9110d544ddb86a3c4766982daaa2a0 100644 (file)
@@ -32,6 +32,8 @@ class UnicodeSchemaTest(PersistTest):
                    Column(u'\u6e2c\u8a66_id', Integer, primary_key=True,
                           autoincrement=False),
                    Column(u'unitable1_\u6e2c\u8a66', Integer,
+                            # lets leave these out for now so that PG tests pass, until
+                            # the test can be broken out into a pg-passing version (or we figure it out)
                           #ForeignKey(u'unitable1.\u6e2c\u8a66')
                           ),
                    Column(u'Unitéble2_b', Integer,
index 9ee201202156d2b85d90291ced061afe94bb71be..ba3670f4dcce238df390ca86da8e2a0e7cfb8b41 100644 (file)
@@ -221,7 +221,7 @@ class SQLCompileTest(PersistTest):
 
         if checkparams is not None:
             if isinstance(checkparams, list):
-                self.assert_(c.get_params().get_raw_list() == checkparams, "params dont match ")
+                self.assert_(c.get_params().get_raw_list({}) == checkparams, "params dont match ")
             else:
                 self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params()))