From: Mike Bayer Date: Sat, 11 Dec 2010 22:44:46 +0000 (-0500) Subject: - support for cdecimal X-Git-Tag: rel_0_7b1~179 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c691b4cbdf7424964f49ac2fd05057514e5856a3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - support for cdecimal - add --with-cdecimal flag to tests, monkeypatches cdecimal in - fix mssql/pyodbc.py to not use private '_int' accessor in decimal conversion routines - pyodbc version 2.1.8 is needed for cdecimal in any case as previous versions also called '_int', 2.1.8 adds the same string logic as our own dialect, so that logic is skipped for modern pyodbc version - make the imports for "Decimal" consistent across the whole lib. not sure yet how we should be importing "Decimal" or what the best way forward is that would allow a clean user-invoked swap of cdecimal; for now, added docs suggesting a global monkeypatch - the two decimal libs are not compatible with each other so any chance of mixing produces serious issues. adding adapters to DBAPIs tedious and adds in-python overhead. suggestions welcome on how we should be doing Decimal/cdecimal. --- diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py index 4c4b0b0705..1f1688a51e 100644 --- a/lib/sqlalchemy/connectors/mxodbc.py +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -15,7 +15,7 @@ For more info on mxODBC, see http://www.egenix.com/ import sys import re import warnings -from decimal import Decimal +from sqlalchemy.util.compat import decimal from sqlalchemy.connectors import Connector from sqlalchemy import types as sqltypes diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index c5f4719426..1368f64141 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -35,7 +35,6 @@ Please consult the pymssql documentation for further information. from sqlalchemy.dialects.mssql.base import MSDialect from sqlalchemy import types as sqltypes, util, processors import re -import decimal class _MSNumeric_pymssql(sqltypes.Numeric): def result_processor(self, dialect, type_): diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 5bba245144..93a516706a 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -88,7 +88,12 @@ class _MSNumeric_pyodbc(sqltypes.Numeric): """ def bind_processor(self, dialect): - super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect) + + super_process = super(_MSNumeric_pyodbc, self).\ + bind_processor(dialect) + + if not dialect._need_decimal_fix: + return super_process def process(value): if self.asdecimal and \ @@ -106,31 +111,35 @@ class _MSNumeric_pyodbc(sqltypes.Numeric): return value return process + # these routines needed for older versions of pyodbc. + # as of 2.1.8 this logic is integrated. + def _small_dec_to_string(self, value): return "%s0.%s%s" % ( (value < 0 and '-' or ''), '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value._int])) + "".join([str(nint) for nint in value.as_tuple()[1]])) def _large_dec_to_string(self, value): + _int = value.as_tuple()[1] if 'E' in str(value): result = "%s%s%s" % ( (value < 0 and '-' or ''), - "".join([str(s) for s in value._int]), - "0" * (value.adjusted() - (len(value._int)-1))) + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int)-1))) else: - if (len(value._int) - 1) > value.adjusted(): + if (len(_int) - 1) > value.adjusted(): result = "%s%s.%s" % ( (value < 0 and '-' or ''), "".join( - [str(s) for s in value._int][0:value.adjusted() + 1]), + [str(s) for s in _int][0:value.adjusted() + 1]), "".join( - [str(s) for s in value._int][value.adjusted() + 1:])) + [str(s) for s in _int][value.adjusted() + 1:])) else: result = "%s%s" % ( (value < 0 and '-' or ''), "".join( - [str(s) for s in value._int][0:value.adjusted() + 1])) + [str(s) for s in _int][0:value.adjusted() + 1])) return result @@ -200,5 +209,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): self.description_encoding = description_encoding self.use_scope_identity = self.dbapi and \ hasattr(self.dbapi.Cursor, 'nextset') + self._need_decimal_fix = self.dbapi and \ + tuple(self.dbapi.version.split(".")) < (2, 1, 8) dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 87a84e514d..b7d6631388 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -121,7 +121,7 @@ from sqlalchemy.engine import base from sqlalchemy import types as sqltypes, util, exc, processors from datetime import datetime import random -from decimal import Decimal +from sqlalchemy.util.compat import decimal import re class _OracleNumeric(sqltypes.Numeric): @@ -148,10 +148,10 @@ class _OracleNumeric(sqltypes.Numeric): def to_decimal(value): if value is None: return None - elif isinstance(value, Decimal): + elif isinstance(value, decimal.Decimal): return value else: - return Decimal(fstring % value) + return decimal.Decimal(fstring % value) return to_decimal else: if self.precision is None and self.scale is None: @@ -569,15 +569,15 @@ class OracleDialect_cx_oracle(OracleDialect): self._detect_decimal = \ lambda value: _detect_decimal(value.replace(char, '.')) self._to_decimal = \ - lambda value: Decimal(value.replace(char, '.')) + lambda value: decimal.Decimal(value.replace(char, '.')) def _detect_decimal(self, value): if "." in value: - return Decimal(value) + return decimal.Decimal(value) else: return int(value) - _to_decimal = Decimal + _to_decimal = decimal.Decimal def on_connect(self): if self.cx_oracle_ver < (5,): diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 7b1d8e6a74..3afa06eaba 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -21,10 +21,9 @@ Passing data from/to the Interval type is not supported as of yet. """ -import decimal - from sqlalchemy.engine import default from sqlalchemy import util, exc +from sqlalchemy.util.compat import decimal from sqlalchemy import processors from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, \ diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 88e6ce6709..b3f42c3306 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -86,10 +86,10 @@ The following per-statement execution options are respected: import random import re -import decimal import logging from sqlalchemy import util, exc +from sqlalchemy.util.compat import decimal from sqlalchemy import processors from sqlalchemy.engine import base, default from sqlalchemy.sql import expression diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index b8f7991d5f..9abdffb6eb 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -8,7 +8,6 @@ URLs are of the form ``postgresql+pypostgresql://user@password@host:port/dbname[ """ from sqlalchemy.engine import default -import decimal from sqlalchemy import util from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index 1d955a7d9c..68b16c051b 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -31,8 +31,8 @@ Currently *not* supported are:: from sqlalchemy.dialects.sybase.base import SybaseDialect,\ SybaseExecutionContext from sqlalchemy.connectors.pyodbc import PyODBCConnector -import decimal from sqlalchemy import types as sqltypes, util, processors +from sqlalchemy.util.compat import decimal class _SybNumeric_pyodbc(sqltypes.Numeric): """Turns Decimals with adjusted() < -6 into floats. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 13dbc6a831..3e592ea514 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -22,14 +22,14 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', import inspect import datetime as dt -from decimal import Decimal as _python_Decimal import codecs from sqlalchemy import exc, schema from sqlalchemy.sql import expression, operators import sys -schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] +schema.types = expression.sqltypes = sys.modules['sqlalchemy.types'] from sqlalchemy.util import pickle +from sqlalchemy.util.compat import decimal from sqlalchemy.sql.visitors import Visitable from sqlalchemy import util from sqlalchemy import processors @@ -1047,7 +1047,38 @@ class Numeric(_DateAffinity, TypeEngine): overhead, and is still subject to floating point data loss - in which case ``asdecimal=False`` will at least remove the extra conversion overhead. - + + Note that the "cdecimal" library is a high performing alternative + to Python's built-in "decimal" type, which performs very poorly + in high volume situations. SQLAlchemy 0.7 is tested against "cdecimal" + as well and supports it fully. The type is not necessarily supported by + DBAPI implementations however, most of which contain an import + for plain "decimal" in their source code, even though + some such as psycopg2 provide hooks for alternate adapters. + SQLAlchemy imports "decimal" globally as well. While the + alternate "Decimal" class can be patched into SQLA's "decimal" module, + overall the most straightforward and foolproof way to use + "cdecimal" given current support is to patch it directly + into sys.modules before anything else is imported:: + + import sys + import cdecimal + sys.modules["decimal"] = cdecimal + + While the global patch is a little ugly, it's particularly + important to use just one decimal library at a time since + Python Decimal and cdecimal Decimal objects + are not currently compatible *with each other*:: + + >>> import cdecimal + >>> import decimal + >>> decimal.Decimal("10") == cdecimal.Decimal("10") + False + + SQLAlchemy will provide more automatic support of + cdecimal if and when it becomes a standard part of Python + installations and is supported by all DBAPIs. + """ self.precision = precision self.scale = scale @@ -1085,10 +1116,10 @@ class Numeric(_DateAffinity, TypeEngine): # we're a "numeric", DBAPI returns floats, convert. if self.scale is not None: return processors.to_decimal_processor_factory( - _python_Decimal, self.scale) + decimal.Decimal, self.scale) else: return processors.to_decimal_processor_factory( - _python_Decimal) + decimal.Decimal) else: if dialect.supports_native_decimal: return processors.to_float @@ -1153,7 +1184,7 @@ class Float(Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - return processors.to_decimal_processor_factory(_python_Decimal) + return processors.to_decimal_processor_factory(decimal.Decimal) else: return None @@ -1928,9 +1959,6 @@ NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() -# using VARCHAR/NCHAR so that we dont get the genericized "String" -# type which usually resolves to TEXT/CLOB - _type_map = { str: String(), # Py3K @@ -1941,7 +1969,7 @@ _type_map = { int : Integer(), float : Numeric(), bool: BOOLEANTYPE, - _python_Decimal : Numeric(), + decimal.Decimal : Numeric(), dt.date : Date(), dt.datetime : DateTime(), dt.time : Time(), diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 961aa1f8a0..59dd9eaf08 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -171,7 +171,6 @@ if win32 or jython: else: time_func = time.time - if sys.version_info >= (2, 5): def decode_slice(slc): """decode a slice object as sent to __getitem__. @@ -188,3 +187,7 @@ if sys.version_info >= (2, 5): else: def decode_slice(slc): return (slc.start, slc.stop, slc.step) + + +import decimal + diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 53d7ff2e4a..26b6c7df4b 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -14,7 +14,7 @@ from sqlalchemy.sql import column from sqlalchemy.processors import to_decimal_processor_factory, \ to_unicode_processor_factory from test.lib.util import gc_collect -from decimal import Decimal as _python_Decimal +from sqlalchemy.util.compat import decimal import gc import weakref from test.orm import _base @@ -580,7 +580,7 @@ class MemUsageTest(EnsureZeroed): def test_DecimalResultProcessor_process(self): @profile_memory def go(): - to_decimal_processor_factory(_python_Decimal, 10)(1.2) + to_decimal_processor_factory(decimal.Decimal, 10)(1.2) go() @testing.requires.cextensions diff --git a/test/bootstrap/config.py b/test/bootstrap/config.py index ef37e4f200..fd43d0ca75 100644 --- a/test/bootstrap/config.py +++ b/test/bootstrap/config.py @@ -48,6 +48,7 @@ def _list_dbs(*args): print "%20s\t%s" % (macro, file_config.get('db', macro)) sys.exit(0) + def _server_side_cursors(options, opt_str, value, parser): db_opts['server_side_cursors'] = True @@ -55,23 +56,15 @@ def _engine_strategy(options, opt_str, value, parser): if value: db_opts['strategy'] = value -class _ordered_map(object): - def __init__(self): - self._keys = list() - self._data = dict() - - def __setitem__(self, key, value): - if key not in self._keys: - self._keys.append(key) - self._data[key] = value - - def __iter__(self): - for key in self._keys: - yield self._data[key] +pre_configure = [] +post_configure = [] -# at one point in refactoring, modules were injecting into the config -# process. this could probably just become a list now. -post_configure = _ordered_map() +def _monkeypatch_cdecimal(options, file_config): + if options.cdecimal: + import sys + import cdecimal + sys.modules['decimal'] = cdecimal +pre_configure.append(_monkeypatch_cdecimal) def _engine_uri(options, file_config): global db_label, db_url @@ -88,7 +81,7 @@ def _engine_uri(options, file_config): raise RuntimeError( "Unknown engine. Specify --dbs for known engines.") db_url = file_config.get('db', db_label) -post_configure['engine_uri'] = _engine_uri +post_configure.append(_engine_uri) def _require(options, file_config): if not(options.require or @@ -114,20 +107,20 @@ def _require(options, file_config): if seen: continue pkg_resources.require(requirement) -post_configure['require'] = _require +post_configure.append(_require) def _engine_pool(options, file_config): if options.mockpool: from sqlalchemy import pool db_opts['poolclass'] = pool.AssertionPool -post_configure['engine_pool'] = _engine_pool +post_configure.append(_engine_pool) def _create_testing_engine(options, file_config): from test.lib import engines, testing global db db = engines.testing_engine(db_url, db_opts) testing.db = db -post_configure['create_engine'] = _create_testing_engine +post_configure.append(_create_testing_engine) def _prep_testing_database(options, file_config): from test.lib import engines @@ -149,7 +142,7 @@ def _prep_testing_database(options, file_config): md.drop_all() e.dispose() -post_configure['prep_db'] = _prep_testing_database +post_configure.append(_prep_testing_database) def _set_table_options(options, file_config): from test.lib import schema @@ -161,7 +154,7 @@ def _set_table_options(options, file_config): if options.mysql_engine: table_options['mysql_engine'] = options.mysql_engine -post_configure['table_options'] = _set_table_options +post_configure.append(_set_table_options) def _reverse_topological(options, file_config): if options.reversetop: @@ -169,5 +162,5 @@ def _reverse_topological(options, file_config): from sqlalchemy import topological from test.lib.util import RandomSet topological.set = unitofwork.set = session.set = mapper.set = dependency.set = RandomSet -post_configure['topological'] = _reverse_topological +post_configure.append(_reverse_topological) diff --git a/test/bootstrap/noseplugin.py b/test/bootstrap/noseplugin.py index 838ad0042a..5ffa8b89fd 100644 --- a/test/bootstrap/noseplugin.py +++ b/test/bootstrap/noseplugin.py @@ -15,7 +15,9 @@ from test.bootstrap import config from test.bootstrap.config import ( _create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log, _prep_testing_database, _require, _reverse_topological, _server_side_cursors, - _set_table_options, base_config, db, db_label, db_url, file_config, post_configure) + _monkeypatch_cdecimal, + _set_table_options, base_config, db, db_label, db_url, file_config, post_configure, + pre_configure) log = logging.getLogger('nose.plugins.sqlalchemy') @@ -56,6 +58,8 @@ class NoseSQLAlchemy(Plugin): opt("--reversetop", action="store_true", dest="reversetop", default=False, help="Use a random-ordering set implementation in the ORM (helps " "reveal dependency issues)") + opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' for all tests") opt("--unhashable", action="store_true", dest="unhashable", default=False, help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") opt("--noncomparable", action="store_true", dest="noncomparable", default=False, @@ -79,7 +83,9 @@ class NoseSQLAlchemy(Plugin): def configure(self, options, conf): Plugin.configure(self, options, conf) self.options = options - + for fn in pre_configure: + fn(self.options, file_config) + def begin(self): global testing, requires, util from test.lib import testing, requires diff --git a/test/dialect/test_maxdb.py b/test/dialect/test_maxdb.py index 7d43d594bc..6ed420d5c4 100644 --- a/test/dialect/test_maxdb.py +++ b/test/dialect/test_maxdb.py @@ -4,7 +4,7 @@ from test.lib.testing import eq_ import StringIO, sys from sqlalchemy import * from sqlalchemy import exc, sql -from decimal import Decimal +from sqlalchemy.util.compat import decimal from sqlalchemy.databases import maxdb from test.lib import * @@ -40,7 +40,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults): _t.create() t = Table('dectest', meta, autoload=True) - vals = [Decimal('2.2'), Decimal('23'), Decimal('2.4'), 25] + vals = [decimal.Decimal('2.2'), decimal.Decimal('23'), decimal.Decimal('2.4'), 25] cols = ['d1','d2','n1','i1'] t.insert().execute(dict(zip(cols,vals))) roundtrip = list(t.select().execute()) diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 63aa874fd0..f3643c4df3 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -14,6 +14,7 @@ from sqlalchemy.engine import url from test.lib import * from test.lib.testing import eq_, emits_warning_on, \ assert_raises_message +from sqlalchemy.util.compat import decimal class CompileTest(TestBase, AssertsCompiledSQL): __dialect__ = mssql.dialect() @@ -1043,7 +1044,6 @@ class TypesTest(TestBase, AssertsExecutionResults, ComparesTables): @testing.fails_on_everything_except('mssql+pyodbc', 'this is some pyodbc-specific feature') def test_decimal_notation(self): - import decimal numeric_table = Table('numeric_table', metadata, Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 3a0fbac9a4..d842c7fc22 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -10,7 +10,7 @@ from test.lib.engines import testing_engine from sqlalchemy.dialects.oracle import cx_oracle, base as oracle from sqlalchemy.engine import default from sqlalchemy.util import jython -from decimal import Decimal +from sqlalchemy.util.compat import decimal import datetime import os @@ -774,12 +774,12 @@ class TypesTest(TestBase, AssertsCompiledSQL): ): for i, (val, type_) in enumerate(( (1, int), - (Decimal("5.2"), Decimal), + (decimal.Decimal("5.2"), decimal.Decimal), (6.5, float), (8.5, float), (9.5, float), (12, int), - (Decimal("14.85"), Decimal), + (decimal.Decimal("14.85"), decimal.Decimal), (15.76, float), )): eq_(row[i], val) @@ -809,8 +809,8 @@ class TypesTest(TestBase, AssertsCompiledSQL): foo.create() foo.insert().execute( - {'idata':5, 'ndata':Decimal("45.6"), 'ndata2':Decimal("45.0"), - 'nidata':Decimal('53'), 'fdata':45.68392}, + {'idata':5, 'ndata':decimal.Decimal("45.6"), 'ndata2':decimal.Decimal("45.0"), + 'nidata':decimal.Decimal('53'), 'fdata':45.68392}, ) stmt = """ @@ -825,10 +825,10 @@ class TypesTest(TestBase, AssertsCompiledSQL): row = testing.db.execute(stmt).fetchall()[0] - eq_([type(x) for x in row], [int, Decimal, Decimal, int, float]) + eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, int, float]) eq_( row, - (5, Decimal('45.6'), Decimal('45'), 53, 45.683920000000001) + (5, decimal.Decimal('45.6'), decimal.Decimal('45'), 53, 45.683920000000001) ) # with a nested subquery, @@ -852,10 +852,10 @@ class TypesTest(TestBase, AssertsCompiledSQL): FROM dual """ row = testing.db.execute(stmt).fetchall()[0] - eq_([type(x) for x in row], [int, Decimal, int, int, Decimal]) + eq_([type(x) for x in row], [int, decimal.Decimal, int, int, decimal.Decimal]) eq_( row, - (5, Decimal('45.6'), 45, 53, Decimal('45.68392')) + (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392')) ) row = testing.db.execute(text(stmt, @@ -866,9 +866,9 @@ class TypesTest(TestBase, AssertsCompiledSQL): 'nidata':Numeric(5, 0), 'fdata':Float() })).fetchall()[0] - eq_([type(x) for x in row], [int, Decimal, Decimal, Decimal, float]) + eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float]) eq_(row, - (5, Decimal('45.6'), Decimal('45'), Decimal('53'), 45.683920000000001) + (5, decimal.Decimal('45.6'), decimal.Decimal('45'), decimal.Decimal('53'), 45.683920000000001) ) stmt = """ @@ -895,8 +895,8 @@ class TypesTest(TestBase, AssertsCompiledSQL): WHERE ROWNUM >= 0) anon_1 """ row =testing.db.execute(stmt).fetchall()[0] - eq_([type(x) for x in row], [int, Decimal, int, int, Decimal]) - eq_(row, (5, Decimal('45.6'), 45, 53, Decimal('45.68392'))) + eq_([type(x) for x in row], [int, decimal.Decimal, int, int, decimal.Decimal]) + eq_(row, (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392'))) row = testing.db.execute(text(stmt, typemap={ @@ -906,9 +906,9 @@ class TypesTest(TestBase, AssertsCompiledSQL): 'anon_1_nidata':Numeric(5, 0), 'anon_1_fdata':Float() })).fetchall()[0] - eq_([type(x) for x in row], [int, Decimal, Decimal, Decimal, float]) + eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float]) eq_(row, - (5, Decimal('45.6'), Decimal('45'), Decimal('53'), 45.683920000000001) + (5, decimal.Decimal('45.6'), decimal.Decimal('45'), decimal.Decimal('53'), 45.683920000000001) ) row = testing.db.execute(text(stmt, @@ -919,9 +919,9 @@ class TypesTest(TestBase, AssertsCompiledSQL): 'anon_1_nidata':Numeric(5, 0, asdecimal=False), 'anon_1_fdata':Float(asdecimal=True) })).fetchall()[0] - eq_([type(x) for x in row], [int, float, float, float, Decimal]) + eq_([type(x) for x in row], [int, float, float, float, decimal.Decimal]) eq_(row, - (5, 45.6, 45, 53, Decimal('45.68392')) + (5, 45.6, 45, 53, decimal.Decimal('45.68392')) ) @@ -1064,11 +1064,11 @@ class EuroNumericTest(TestBase): @testing.provide_metadata def test_output_type_handler(self): for stmt, exp, kw in [ - ("SELECT 0.1 FROM DUAL", Decimal("0.1"), {}), + ("SELECT 0.1 FROM DUAL", decimal.Decimal("0.1"), {}), ("SELECT 15 FROM DUAL", 15, {}), - ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", Decimal("15"), {}), - ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", Decimal("0.1"), {}), - ("SELECT :num FROM DUAL", Decimal("2.5"), {'num':Decimal("2.5")}) + ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", decimal.Decimal("15"), {}), + ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", decimal.Decimal("0.1"), {}), + ("SELECT :num FROM DUAL", decimal.Decimal("2.5"), {'num':decimal.Decimal("2.5")}) ]: test_exp = self.engine.scalar(stmt, **kw) eq_( diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index bff1fba68a..cfccb9bb1c 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -2,12 +2,12 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message from test.lib import engines import datetime -import decimal from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exc, schema, types from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.engine.strategies import MockEngineStrategy +from sqlalchemy.util.compat import decimal from test.lib import * from test.lib.util import round_decimal from sqlalchemy.sql import table, column @@ -480,7 +480,7 @@ class NumericInterpretationTest(TestBase): def test_numeric_codes(self): from sqlalchemy.dialects.postgresql import pg8000, psycopg2, base - from decimal import Decimal + from sqlalchemy.util.compat import decimal for dialect in (pg8000.dialect(), psycopg2.dialect()): @@ -491,7 +491,7 @@ class NumericInterpretationTest(TestBase): val = 23.7 if proc is not None: val = proc(val) - assert val in (23.7, Decimal("23.7")) + assert val in (23.7, decimal.Decimal("23.7")) class InsertTest(TestBase, AssertsExecutionResults): diff --git a/test/lib/util.py b/test/lib/util.py index b8cc05a818..4c9892852f 100644 --- a/test/lib/util.py +++ b/test/lib/util.py @@ -1,4 +1,5 @@ from sqlalchemy.util import jython, defaultdict, decorator +from sqlalchemy.util.compat import decimal import gc import time @@ -44,8 +45,6 @@ def round_decimal(value, prec): if isinstance(value, float): return round(value, prec) - import decimal - # can also use shift() here but that is 2.6 only return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \ pow(10, prec) diff --git a/test/perf/stress_all.py b/test/perf/stress_all.py index ad074ee533..a19be95795 100644 --- a/test/perf/stress_all.py +++ b/test/perf/stress_all.py @@ -1,6 +1,6 @@ # -*- encoding: utf8 -*- from datetime import * -from decimal import Decimal +from sqlalchemy.util.compat import decimal #from fastdec import mpd as Decimal from cPickle import dumps, loads diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 396eaaf9ba..0fb2ca5f7f 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -10,7 +10,7 @@ from sqlalchemy import types as sqltypes from test.lib import * from sqlalchemy.sql.functions import GenericFunction from test.lib.testing import eq_ -from decimal import Decimal as _python_Decimal +from sqlalchemy.util.compat import decimal from test.lib import testing from sqlalchemy.databases import * @@ -107,7 +107,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date), ((3, 5), sqltypes.Integer), - ((_python_Decimal(3), _python_Decimal(5)), sqltypes.Numeric), + ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric), (("foo", "bar"), sqltypes.String), ((datetime.datetime(2007, 10, 5, 8, 3, 34), datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 73d99ac6a4..3d9be543c7 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -11,7 +11,7 @@ from sqlalchemy.databases import * from test.lib.schema import Table, Column from test.lib import * from test.lib.util import picklers -from decimal import Decimal +from sqlalchemy.util.compat import decimal from test.lib.util import round_decimal @@ -1262,8 +1262,8 @@ class NumericTest(TestBase): def test_numeric_as_decimal(self): self._do_test( Numeric(precision=8, scale=4), - [15.7563, Decimal("15.7563"), None], - [Decimal("15.7563"), None], + [15.7563, decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], ) def test_numeric_as_float(self): @@ -1274,7 +1274,7 @@ class NumericTest(TestBase): self._do_test( Numeric(precision=8, scale=4, asdecimal=False), - [15.7563, Decimal("15.7563"), None], + [15.7563, decimal.Decimal("15.7563"), None], [15.7563, None], filter_ = filter_ ) @@ -1282,15 +1282,15 @@ class NumericTest(TestBase): def test_float_as_decimal(self): self._do_test( Float(precision=8, asdecimal=True), - [15.7563, Decimal("15.7563"), None], - [Decimal("15.7563"), None], + [15.7563, decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], filter_ = lambda n:n is not None and round(n, 5) or None ) def test_float_as_float(self): self._do_test( Float(precision=8), - [15.7563, Decimal("15.7563")], + [15.7563, decimal.Decimal("15.7563")], [15.7563], filter_ = lambda n:n is not None and round(n, 5) or None ) @@ -1390,18 +1390,18 @@ class NumericRawSQLTest(TestBase): @testing.fails_on('sqlite', "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_fp(self): - t = self._fixture(metadata, Numeric(10, 5), Decimal("45.5")) + t = self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5")) val = testing.db.execute("select val from t").scalar() - assert isinstance(val, Decimal) - eq_(val, Decimal("45.5")) + assert isinstance(val, decimal.Decimal) + eq_(val, decimal.Decimal("45.5")) @testing.fails_on('sqlite', "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_int(self): - t = self._fixture(metadata, Numeric(10, 5), Decimal("45")) + t = self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45")) val = testing.db.execute("select val from t").scalar() - assert isinstance(val, Decimal) - eq_(val, Decimal("45")) + assert isinstance(val, decimal.Decimal) + eq_(val, decimal.Decimal("45")) @testing.provide_metadata def test_ints(self):