]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
switching Decimal treatment in MSSQL to be pyodbc specific, added
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2010 15:17:14 +0000 (11:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2010 15:17:14 +0000 (11:17 -0400)
to connector to share between sybase/mssql.   Going
with turning decimals with very low significant digit to floats,
seems to work so far.

lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mssql/adodbapi.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/sybase/pyodbc.py

index e503135f70ed29cb8806e978f33883ce06185817..5cfe4a1921b8c5918b5c8f753fc23a471cdb1490 100644 (file)
@@ -4,6 +4,59 @@ from sqlalchemy.util import asbool
 import sys
 import re
 import urllib
+import decimal
+from sqlalchemy import processors, types as sqltypes
+
+class PyODBCNumeric(sqltypes.Numeric):
+    """Turns Decimals with adjusted() < -6 into floats."""
+    
+    def bind_processor(self, dialect):
+        super_process = super(PyODBCNumeric, self).bind_processor(dialect)
+        
+        def process(value):
+            if self.asdecimal and \
+                    isinstance(value, decimal.Decimal) and \
+                    value.adjusted() < -6:
+                return processors.to_float(value)
+            elif super_process:
+                return super_process(value)
+            else:
+                return value
+        return process
+
+    # This method turns the adjusted into a string.
+    # not sure if this has advantages over the simple float
+    # approach above.
+#    def bind_processor(self, dialect):
+#        def process(value):
+#            if isinstance(value, decimal.Decimal):
+#                if value.adjusted() < 0:
+#                    result = "%s0.%s%s" % (
+#                            (value < 0 and '-' or ''),
+#                            '0' * (abs(value.adjusted()) - 1),
+#                            "".join([str(nint) for nint in value._int]))
+#
+#                else:
+#                    if 'E' in str(value):
+#                        result = "%s%s%s" % (
+#                                (value < 0 and '-' or ''),
+#                                "".join([str(s) for s in value._int]),
+#                                "0" * (value.adjusted() - (len(value._int)-1)))
+#                    else:
+#                        if (len(value._int) - 1) > value.adjusted():
+#                            result = "%s%s.%s" % (
+#                                    (value < 0 and '-' or ''),
+#                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
+#                                    "".join([str(s) for s in value._int][value.adjusted() + 1:]))
+#                        else:
+#                            result = "%s%s" % (
+#                                    (value < 0 and '-' or ''),
+#                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
+#                return result
+#
+#            else:
+#                return value
+#        return process
 
 class PyODBCConnector(Connector):
     driver='pyodbc'
index 6ca1879d612674e6e368403ca8262468e1fd4c3d..9e12a944d7c7a37a97dd3d1b9460e166d715e34a 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy import types as sqltypes
+from sqlalchemy import types as sqltypes, util
 from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
 import sys
 
@@ -25,8 +25,12 @@ class MSDialect_adodbapi(MSDialect):
         import adodbapi as module
         return module
 
-    colspecs = MSDialect.colspecs.copy()
-    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+    colspecs = util.update_copy(
+        MSDialect.colspecs,
+        {
+            sqltypes.DateTime:MSDateTime_adodbapi
+        }
+    )
 
     def create_connect_args(self, url):
         keys = url.query
index c7713ac4d06d0195e16ecaaf1bb633cc6f5e38ba..7660fe9f720515adc75fbeb816b7fa285e63057f 100644 (file)
@@ -276,46 +276,6 @@ RESERVED_WORDS = set(
     ])
 
 
-class _MSNumeric(sqltypes.Numeric):
-    
-    def bind_processor(self, dialect):
-        def process(value):
-            # TODO: this seems exceedingly complex. 
-            # need to know exactly what tests cover this, so far
-            # test_types.NumericTest.test_enotation_decimal
-            # see the _SybNumeric type in sybase/pyodbc for possible
-            # generalized solution on pyodbc
-            if isinstance(value, decimal.Decimal):
-                if value.adjusted() < 0:
-                    result = "%s0.%s%s" % (
-                            (value < 0 and '-' or ''),
-                            '0' * (abs(value.adjusted()) - 1),
-                            "".join([str(nint) for nint in value._int]))
-
-                else:
-                    if 'E' in str(value):
-                        result = "%s%s%s" % (
-                                (value < 0 and '-' or ''),
-                                "".join([str(s) for s in value._int]),
-                                "0" * (value.adjusted() - (len(value._int)-1)))
-                    else:
-                        if (len(value._int) - 1) > value.adjusted():
-                            result = "%s%s.%s" % (
-                                    (value < 0 and '-' or ''),
-                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
-                                    "".join([str(s) for s in value._int][value.adjusted() + 1:]))
-                        else:
-                            result = "%s%s" % (
-                                    (value < 0 and '-' or ''),
-                                    "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
-
-                return result
-
-            else:
-                return value
-
-        return process
-
 class REAL(sqltypes.Float):
     """A type for ``real`` numbers."""
 
@@ -411,27 +371,12 @@ class DATETIMEOFFSET(sqltypes.TypeEngine):
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
 
-
 class _StringType(object):
     """Base for MSSQL string types."""
 
     def __init__(self, collation=None):
         self.collation = collation
 
-    def __repr__(self):
-        attributes = inspect.getargspec(self.__init__)[0][1:]
-        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
-        params = {}
-        for attr in attributes:
-            val = getattr(self, attr)
-            if val is not None and val is not False:
-                params[attr] = val
-
-        return "%s(%s)" % (self.__class__.__name__,
-                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
-
 class TEXT(_StringType, sqltypes.TEXT):
     """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
 
@@ -579,7 +524,6 @@ class SQL_VARIANT(sqltypes.TypeEngine):
     __visit_name__ = 'SQL_VARIANT'
 
 # old names.
-MSNumeric = _MSNumeric
 MSDateTime = _MSDateTime
 MSDate = _MSDate
 MSReal = REAL
@@ -603,13 +547,6 @@ MSSmallMoney = SMALLMONEY
 MSUniqueIdentifier = UNIQUEIDENTIFIER
 MSVariant = SQL_VARIANT
 
-colspecs = {
-    sqltypes.Numeric : _MSNumeric,
-    sqltypes.DateTime : _MSDateTime,
-    sqltypes.Date : _MSDate,
-    sqltypes.Time : TIME,
-}
-
 ischema_names = {
     'int' : INTEGER,
     'bigint': BIGINT,
@@ -1146,7 +1083,13 @@ class MSDialect(default.DefaultDialect):
     use_scope_identity = True
     max_identifier_length = 128
     schema_name = "dbo"
-    colspecs = colspecs
+
+    colspecs = {
+        sqltypes.DateTime : _MSDateTime,
+        sqltypes.Date : _MSDate,
+        sqltypes.Time : TIME,
+    }
+
     ischema_names = ischema_names
     
     supports_native_boolean = False
index 34050271fb020908e4ab468b314b9ca792831c8e..b22d742dea0064c16a27725859f08d1586e0f38d 100644 (file)
@@ -12,11 +12,12 @@ Connect strings are of the form::
 """
 
 from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
-from sqlalchemy.connectors.pyodbc import PyODBCConnector
-from sqlalchemy import types as sqltypes
-import re
-import sys
+from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric
+from sqlalchemy import types as sqltypes, util
 
+class _MSNumeric_pyodbc(PyODBCNumeric):
+    pass
+    
 class MSExecutionContext_pyodbc(MSExecutionContext):
     _embedded_scope_identity = False
     
@@ -67,7 +68,14 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
     execution_ctx_cls = MSExecutionContext_pyodbc
 
     pyodbc_driver_name = 'SQL Server'
-
+    
+    colspecs = util.update_copy(
+        MSDialect.colspecs,
+        {
+            sqltypes.Numeric:_MSNumeric_pyodbc
+        }
+    )
+    
     def __init__(self, description_encoding='latin-1', **params):
         super(MSDialect_pyodbc, self).__init__(**params)
         self.description_encoding = description_encoding
index 61cf333da9486df654424769a25ef0181396204d..89381593048d39a1df1bd6a633135d48b8cbdc01 100644 (file)
@@ -29,27 +29,12 @@ Currently *not* supported are::
 """
 
 from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
-from sqlalchemy.connectors.pyodbc import PyODBCConnector
-
-import decimal
-from sqlalchemy import processors, types as sqltypes
-
-# TODO: should this be part of pyodbc connectors ??? applies to MSSQL too ?
-class _SybNumeric(sqltypes.Numeric):
-    def bind_processor(self, dialect):
-        super_process = super(_SybNumeric, self).bind_processor(dialect)
-        
-        def process(value):
-            if self.asdecimal and \
-                    isinstance(value, decimal.Decimal) and \
-                    value.adjusted() < -6:
-                return processors.to_float(value)
-            elif super_process:
-                return super_process(value)
-            else:
-                return value
-        return process
+from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric
 
+from sqlalchemy import types as sqltypes, util
+
+class _SybNumeric_pyodbc(PyODBCNumeric):
+    pass
 
 class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
     def set_ddl_autocommit(self, connection, value):
@@ -64,8 +49,7 @@ class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
     execution_ctx_cls = SybaseExecutionContext_pyodbc
 
     colspecs = {
-        sqltypes.Numeric:_SybNumeric,
-        sqltypes.Float:sqltypes.Float,
+        sqltypes.Numeric:_SybNumeric_pyodbc,
     }
 
 dialect = SybaseDialect_pyodbc