From: Mike Bayer Date: Fri, 19 Mar 2010 15:17:14 +0000 (-0400) Subject: switching Decimal treatment in MSSQL to be pyodbc specific, added X-Git-Tag: rel_0_6beta2~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5be0d3133bb3591ca31e2da0a01fb3d3038aa9f8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git switching Decimal treatment in MSSQL to be pyodbc specific, added to connector to share between sybase/mssql. Going with turning decimals with very low significant digit to floats, seems to work so far. --- diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index e503135f70..5cfe4a1921 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -4,6 +4,59 @@ from sqlalchemy.util import asbool import sys import re import urllib +import decimal +from sqlalchemy import processors, types as sqltypes + +class PyODBCNumeric(sqltypes.Numeric): + """Turns Decimals with adjusted() < -6 into floats.""" + + def bind_processor(self, dialect): + super_process = super(PyODBCNumeric, self).bind_processor(dialect) + + def process(value): + if self.asdecimal and \ + isinstance(value, decimal.Decimal) and \ + value.adjusted() < -6: + return processors.to_float(value) + elif super_process: + return super_process(value) + else: + return value + return process + + # This method turns the adjusted into a string. + # not sure if this has advantages over the simple float + # approach above. +# def bind_processor(self, dialect): +# def process(value): +# if isinstance(value, decimal.Decimal): +# if value.adjusted() < 0: +# result = "%s0.%s%s" % ( +# (value < 0 and '-' or ''), +# '0' * (abs(value.adjusted()) - 1), +# "".join([str(nint) for nint in value._int])) +# +# else: +# if 'E' in str(value): +# result = "%s%s%s" % ( +# (value < 0 and '-' or ''), +# "".join([str(s) for s in value._int]), +# "0" * (value.adjusted() - (len(value._int)-1))) +# else: +# if (len(value._int) - 1) > value.adjusted(): +# result = "%s%s.%s" % ( +# (value < 0 and '-' or ''), +# "".join([str(s) for s in value._int][0:value.adjusted() + 1]), +# "".join([str(s) for s in value._int][value.adjusted() + 1:])) +# else: +# result = "%s%s" % ( +# (value < 0 and '-' or ''), +# "".join([str(s) for s in value._int][0:value.adjusted() + 1])) +# return result +# +# else: +# return value +# return process class PyODBCConnector(Connector): driver='pyodbc' diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py index 6ca1879d61..9e12a944d7 100644 --- a/lib/sqlalchemy/dialects/mssql/adodbapi.py +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -1,4 +1,4 @@ -from sqlalchemy import types as sqltypes +from sqlalchemy import types as sqltypes, util from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect import sys @@ -25,8 +25,12 @@ class MSDialect_adodbapi(MSDialect): import adodbapi as module return module - colspecs = MSDialect.colspecs.copy() - colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.DateTime:MSDateTime_adodbapi + } + ) def create_connect_args(self, url): keys = url.query diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index c7713ac4d0..7660fe9f72 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -276,46 +276,6 @@ RESERVED_WORDS = set( ]) -class _MSNumeric(sqltypes.Numeric): - - def bind_processor(self, dialect): - def process(value): - # TODO: this seems exceedingly complex. - # need to know exactly what tests cover this, so far - # test_types.NumericTest.test_enotation_decimal - # see the _SybNumeric type in sybase/pyodbc for possible - # generalized solution on pyodbc - if isinstance(value, decimal.Decimal): - if value.adjusted() < 0: - result = "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value._int])) - - else: - if 'E' in str(value): - result = "%s%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int]), - "0" * (value.adjusted() - (len(value._int)-1))) - else: - if (len(value._int) - 1) > value.adjusted(): - result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1]), - "".join([str(s) for s in value._int][value.adjusted() + 1:])) - else: - result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1])) - - return result - - else: - return value - - return process - class REAL(sqltypes.Float): """A type for ``real`` numbers.""" @@ -411,27 +371,12 @@ class DATETIMEOFFSET(sqltypes.TypeEngine): def __init__(self, precision=None, **kwargs): self.precision = precision - class _StringType(object): """Base for MSSQL string types.""" def __init__(self, collation=None): self.collation = collation - def __repr__(self): - attributes = inspect.getargspec(self.__init__)[0][1:] - attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - - params = {} - for attr in attributes: - val = getattr(self, attr) - if val is not None and val is not False: - params[attr] = val - - return "%s(%s)" % (self.__class__.__name__, - ', '.join(['%s=%r' % (k, params[k]) for k in params])) - - class TEXT(_StringType, sqltypes.TEXT): """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" @@ -579,7 +524,6 @@ class SQL_VARIANT(sqltypes.TypeEngine): __visit_name__ = 'SQL_VARIANT' # old names. -MSNumeric = _MSNumeric MSDateTime = _MSDateTime MSDate = _MSDate MSReal = REAL @@ -603,13 +547,6 @@ MSSmallMoney = SMALLMONEY MSUniqueIdentifier = UNIQUEIDENTIFIER MSVariant = SQL_VARIANT -colspecs = { - sqltypes.Numeric : _MSNumeric, - sqltypes.DateTime : _MSDateTime, - sqltypes.Date : _MSDate, - sqltypes.Time : TIME, -} - ischema_names = { 'int' : INTEGER, 'bigint': BIGINT, @@ -1146,7 +1083,13 @@ class MSDialect(default.DefaultDialect): use_scope_identity = True max_identifier_length = 128 schema_name = "dbo" - colspecs = colspecs + + colspecs = { + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + } + ischema_names = ischema_names supports_native_boolean = False diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 34050271fb..b22d742dea 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -12,11 +12,12 @@ Connect strings are of the form:: """ from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect -from sqlalchemy.connectors.pyodbc import PyODBCConnector -from sqlalchemy import types as sqltypes -import re -import sys +from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric +from sqlalchemy import types as sqltypes, util +class _MSNumeric_pyodbc(PyODBCNumeric): + pass + class MSExecutionContext_pyodbc(MSExecutionContext): _embedded_scope_identity = False @@ -67,7 +68,14 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): execution_ctx_cls = MSExecutionContext_pyodbc pyodbc_driver_name = 'SQL Server' - + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric:_MSNumeric_pyodbc + } + ) + def __init__(self, description_encoding='latin-1', **params): super(MSDialect_pyodbc, self).__init__(**params) self.description_encoding = description_encoding diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index 61cf333da9..8938159304 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -29,27 +29,12 @@ Currently *not* supported are:: """ from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext -from sqlalchemy.connectors.pyodbc import PyODBCConnector - -import decimal -from sqlalchemy import processors, types as sqltypes - -# TODO: should this be part of pyodbc connectors ??? applies to MSSQL too ? -class _SybNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - super_process = super(_SybNumeric, self).bind_processor(dialect) - - def process(value): - if self.asdecimal and \ - isinstance(value, decimal.Decimal) and \ - value.adjusted() < -6: - return processors.to_float(value) - elif super_process: - return super_process(value) - else: - return value - return process +from sqlalchemy.connectors.pyodbc import PyODBCConnector, PyODBCNumeric +from sqlalchemy import types as sqltypes, util + +class _SybNumeric_pyodbc(PyODBCNumeric): + pass class SybaseExecutionContext_pyodbc(SybaseExecutionContext): def set_ddl_autocommit(self, connection, value): @@ -64,8 +49,7 @@ class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): execution_ctx_cls = SybaseExecutionContext_pyodbc colspecs = { - sqltypes.Numeric:_SybNumeric, - sqltypes.Float:sqltypes.Float, + sqltypes.Numeric:_SybNumeric_pyodbc, } dialect = SybaseDialect_pyodbc