From: Mike Bayer Date: Thu, 18 Mar 2010 15:48:24 +0000 (-0400) Subject: - moved most Decimal bind/result handling into types.py, out of sqlite, mysql dialects. X-Git-Tag: rel_0_6beta2~37^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=074cab9e7d01533302e84a489d740accad25476a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - moved most Decimal bind/result handling into types.py, out of sqlite, mysql dialects. - added an explicit test for [ticket:1216] - some questions remain about MSSQL - would like to simplify/remove bind handling for numerics --- diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py index 68b88019c2..de9b18f1a4 100644 --- a/lib/sqlalchemy/connectors/mxodbc.py +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -11,7 +11,8 @@ class MxODBCConnector(Connector): supports_sane_multi_rowcount = False supports_unicode_statements = False supports_unicode_binds = False - supports_native_decimal = False + + supports_native_decimal = True @classmethod def dbapi(cls): diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index de13a2dce1..a2da132daf 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -69,7 +69,7 @@ the SQLAlchemy ``returning()`` method, such as:: """ -import datetime, decimal, re +import datetime, re from sqlalchemy import schema as sa_schema from sqlalchemy import exc, types as sqltypes, sql, util diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index bd5af4e8e4..070650d97f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -277,27 +277,13 @@ RESERVED_WORDS = set( class _MSNumeric(sqltypes.Numeric): - def result_processor(self, dialect, coltype): - if self.asdecimal: - # TODO: factor this down into the sqltypes.Numeric class, - # use dialect flags - if getattr(self, 'scale', None) is None: - # we're a "float". return a default decimal factory - return processors.to_decimal_processor_factory(decimal.Decimal) - elif dialect.supports_native_decimal: - # we're a "numeric", DBAPI will give us Decimal directly - return None - else: - # we're a "numeric", DBAPI returns floats, convert. - return processors.to_decimal_processor_factory(decimal.Decimal, self.scale) - else: - #XXX: if the DBAPI returns a float (this is likely, given the - # processor when asdecimal is True), this should be a None - # processor instead. - return processors.to_float - + def bind_processor(self, dialect): def process(value): + # TODO: this seems exceedingly complex. + # need to know exactly what tests cover this, so far + # test_types.NumericTest.test_enotation_decimal + if isinstance(value, decimal.Decimal): if value.adjusted() < 0: result = "%s0.%s%s" % ( diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 33c18f6d62..981e1e204b 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -8,7 +8,7 @@ import re from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer, - BIT, NUMERIC, _NumericType) + BIT) from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators @@ -28,15 +28,6 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler): def post_process_text(self, text): return text.replace('%', '%%') -class _DecimalType(_NumericType): - def result_processor(self, dialect, coltype): - if self.asdecimal: - return None - return processors.to_float - -class _myconnpyNumeric(_DecimalType, NUMERIC): - pass - class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): def _escape_identifier(self, value): @@ -56,6 +47,8 @@ class MySQLDialect_mysqlconnector(MySQLDialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True + supports_native_decimal = True + default_paramstyle = 'format' execution_ctx_cls = MySQLExecutionContext_mysqlconnector statement_compiler = MySQLCompiler_mysqlconnector @@ -65,7 +58,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect): colspecs = util.update_copy( MySQLDialect.colspecs, { - sqltypes.Numeric: _myconnpyNumeric, BIT: _myconnpyBIT, } ) diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 038e58a4c3..9d34939a1f 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -20,11 +20,10 @@ strings, also pass ``use_unicode=0`` in the connection arguments:: create_engine('mysql:///mydb?charset=utf8&use_unicode=0') """ -import decimal import re -from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType) +from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext, + MySQLCompiler, MySQLIdentifierPreparer) from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators from sqlalchemy import exc, log, schema, sql, types as sqltypes, util @@ -48,20 +47,6 @@ class MySQLCompiler_mysqldb(MySQLCompiler): return text.replace('%', '%%') -class _DecimalType(_NumericType): - def result_processor(self, dialect, coltype): - if self.asdecimal: - return None - return processors.to_float - - -class _MySQLdbNumeric(_DecimalType, NUMERIC): - pass - - -class _MySQLdbDecimal(_DecimalType, DECIMAL): - pass - class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): def _escape_identifier(self, value): @@ -74,6 +59,8 @@ class MySQLDialect_mysqldb(MySQLDialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True + supports_native_decimal = True + default_paramstyle = 'format' execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb @@ -82,8 +69,6 @@ class MySQLDialect_mysqldb(MySQLDialect): colspecs = util.update_copy( MySQLDialect.colspecs, { - sqltypes.Numeric: _MySQLdbNumeric, - DECIMAL: _MySQLdbDecimal } ) diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 605b39760f..f26bc4da2f 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -21,23 +21,16 @@ defaults to, there is a separate parameter:: create_engine('mysql+oursql:///mydb?charset=latin1') """ -import decimal import re from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType) + MySQLCompiler, MySQLIdentifierPreparer) from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators from sqlalchemy import exc, log, schema, sql, types as sqltypes, util from sqlalchemy import processors -class _oursqlNumeric(NUMERIC): - def result_processor(self, dialect, coltype): - if self.asdecimal: - return None - return processors.to_float - class _oursqlBIT(BIT): def result_processor(self, dialect, coltype): @@ -60,7 +53,9 @@ class MySQLDialect_oursql(MySQLDialect): supports_unicode_binds = True supports_unicode_statements = True # end Py2K - + + supports_native_decimal = True + supports_sane_rowcount = True supports_sane_multi_rowcount = True execution_ctx_cls = MySQLExecutionContext_oursql @@ -69,7 +64,6 @@ class MySQLDialect_oursql(MySQLDialect): MySQLDialect.colspecs, { sqltypes.Time: sqltypes.Time, - sqltypes.Numeric: _oursqlNumeric, BIT: _oursqlBIT, } ) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 98df8d0cb4..d7637e71b4 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -61,19 +61,6 @@ from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ TIMESTAMP, VARCHAR -class _NumericMixin(object): - def bind_processor(self, dialect): - if self.asdecimal: - return processors.to_str - else: - return processors.to_float - -class _SLNumeric(_NumericMixin, sqltypes.Numeric): - pass - -class _SLFloat(_NumericMixin, sqltypes.Float): - pass - class _DateTimeMixin(object): _reg = None _storage_format = None @@ -163,8 +150,6 @@ class TIME(_DateTimeMixin, sqltypes.Time): colspecs = { sqltypes.Date: DATE, sqltypes.DateTime: DATETIME, - sqltypes.Float: _SLFloat, - sqltypes.Numeric: _SLNumeric, sqltypes.Time: TIME, } diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 53f32fb2e9..3feac8f4fc 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -928,20 +928,24 @@ class Numeric(_DateAffinity, TypeEngine): return dbapi.NUMBER def bind_processor(self, dialect): - return processors.to_float + if dialect.supports_native_decimal: + return None + else: + return processors.to_float def result_processor(self, dialect, coltype): if self.asdecimal: - #XXX: use decimal from http://www.bytereef.org/libmpdec.html -# try: -# from fastdec import mpd as Decimal -# except ImportError: - if self.scale is not None: - return processors.to_decimal_processor_factory(_python_Decimal, self.scale) + if dialect.supports_native_decimal: + # we're a "numeric", DBAPI will give us Decimal directly + return None else: - return processors.to_decimal_processor_factory(_python_Decimal) + # we're a "numeric", DBAPI returns floats, convert. + return processors.to_decimal_processor_factory(_python_Decimal, self.scale) else: - return None + if dialect.supports_native_decimal: + return processors.to_float + else: + return None @util.memoized_property def _expression_adaptations(self): @@ -980,10 +984,6 @@ class Float(Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - #XXX: use decimal from http://www.bytereef.org/libmpdec.html -# try: -# from fastdec import mpd as Decimal -# except ImportError: return processors.to_decimal_processor_factory(_python_Decimal) else: return None diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 6404783a5a..507ca5a091 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -11,6 +11,8 @@ from sqlalchemy.databases import * from sqlalchemy.test.schema import Table, Column from sqlalchemy.test import * from sqlalchemy.test.util import picklers +from decimal import Decimal +from sqlalchemy.test.util import round_decimal class AdaptTest(TestBase): @@ -1084,14 +1086,6 @@ class StringTest(TestBase, AssertsExecutionResults): foo.create() foo.drop() -def _missing_decimal(): - """Python implementation supports decimals""" - try: - import decimal - return False - except ImportError: - return True - class NumericTest(TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): @@ -1114,7 +1108,6 @@ class NumericTest(TestBase, AssertsExecutionResults): def teardown(self): numeric_table.delete().execute() - @testing.fails_if(_missing_decimal) def test_decimal(self): from decimal import Decimal numeric_table.insert().execute( @@ -1134,10 +1127,7 @@ class NumericTest(TestBase, AssertsExecutionResults): (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")), ]) - @testing.fails_if(_missing_decimal) def test_precision_decimal(self): - from decimal import Decimal - from sqlalchemy.test.util import round_decimal t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12))) t.create(testing.db) @@ -1154,13 +1144,50 @@ class NumericTest(TestBase, AssertsExecutionResults): ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()]) - numbers = set(round_decimal(n, 11) for n in numbers) - ret = set(round_decimal(n, 11) for n in ret) + if testing.against('sqlite'): + numbers = set(round_decimal(n, 11) for n in numbers) + ret = set(round_decimal(n, 11) for n in ret) + else: + numbers = set(n for n in numbers) + ret = set(n for n in ret) eq_(numbers, ret) finally: t.drop(testing.db) + + @testing.fails_on('sybase', "Driver doesn't appear to handle E notation, won't accept strings") + def test_enotation_decimal(self): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12))) + t.create(testing.db) + try: + numbers = set([ + decimal.Decimal('1E-2'), + decimal.Decimal('1E-3'), + decimal.Decimal('1E-4'), + decimal.Decimal('1E-5'), + decimal.Decimal('1E-6'), + decimal.Decimal('1E-7'), + decimal.Decimal('1E-8'), + ]) + + testing.db.execute(t.insert(), [{'x':x} for x in numbers]) + + ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()]) + numbers = set(n for n in numbers) + ret = set(n for n in ret) + + eq_(numbers, ret) + finally: + t.drop(testing.db) + def test_decimal_fallback(self): from decimal import Decimal