]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- support for cdecimal
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 22:44:46 +0000 (17:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 22:44:46 +0000 (17:44 -0500)
- add --with-cdecimal flag to tests, monkeypatches cdecimal in
- fix mssql/pyodbc.py to not use private '_int' accessor in decimal conversion
routines
- pyodbc version 2.1.8 is needed for cdecimal in any case as
previous versions also called '_int', 2.1.8 adds the same string
logic as our own dialect, so that logic is skipped for modern
pyodbc version
- make the imports for "Decimal" consistent across the whole lib.  not sure
yet how we should be importing "Decimal" or what the best way forward
is that would allow a clean user-invoked swap of cdecimal; for now,
added docs suggesting a global monkeypatch - the two decimal libs
are not compatible with each other so any chance of mixing produces
serious issues.  adding adapters to DBAPIs tedious and adds in-python
overhead.  suggestions welcome on how we should be doing
Decimal/cdecimal.

21 files changed:
lib/sqlalchemy/connectors/mxodbc.py
lib/sqlalchemy/dialects/mssql/pymssql.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/pypostgresql.py
lib/sqlalchemy/dialects/sybase/pyodbc.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/compat.py
test/aaa_profiling/test_memusage.py
test/bootstrap/config.py
test/bootstrap/noseplugin.py
test/dialect/test_maxdb.py
test/dialect/test_mssql.py
test/dialect/test_oracle.py
test/dialect/test_postgresql.py
test/lib/util.py
test/perf/stress_all.py
test/sql/test_functions.py
test/sql/test_types.py

index 4c4b0b07058ff937ed7fc6050a5a598f7df1d228..1f1688a51ebd340cb0d385514e4bc9d31b2d99c2 100644 (file)
@@ -15,7 +15,7 @@ For more info on mxODBC, see http://www.egenix.com/
 import sys
 import re
 import warnings
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 
 from sqlalchemy.connectors import Connector
 from sqlalchemy import types as sqltypes
index c5f4719426b6e7007f1ddc0a16f7459c7f659f78..1368f64141027cb52e9093c7678612cb5b144742 100644 (file)
@@ -35,7 +35,6 @@ Please consult the pymssql documentation for further information.
 from sqlalchemy.dialects.mssql.base import MSDialect
 from sqlalchemy import types as sqltypes, util, processors
 import re
-import decimal
 
 class _MSNumeric_pymssql(sqltypes.Numeric):
     def result_processor(self, dialect, type_):
index 5bba245144df309767152bcdeeffdae6c12c9cf9..93a516706a2e0222d439c6d80193868998dc4baf 100644 (file)
@@ -88,7 +88,12 @@ class _MSNumeric_pyodbc(sqltypes.Numeric):
     """
 
     def bind_processor(self, dialect):
-        super_process = super(_MSNumeric_pyodbc, self).bind_processor(dialect)
+        
+        super_process = super(_MSNumeric_pyodbc, self).\
+                        bind_processor(dialect)
+
+        if not dialect._need_decimal_fix:
+            return super_process
 
         def process(value):
             if self.asdecimal and \
@@ -106,31 +111,35 @@ class _MSNumeric_pyodbc(sqltypes.Numeric):
                 return value
         return process
     
+    # these routines needed for older versions of pyodbc.
+    # as of 2.1.8 this logic is integrated.
+    
     def _small_dec_to_string(self, value):
         return "%s0.%s%s" % (
                     (value < 0 and '-' or ''),
                     '0' * (abs(value.adjusted()) - 1),
-                    "".join([str(nint) for nint in value._int]))
+                    "".join([str(nint) for nint in value.as_tuple()[1]]))
 
     def _large_dec_to_string(self, value):
+        _int = value.as_tuple()[1]
         if 'E' in str(value):
             result = "%s%s%s" % (
                     (value < 0 and '-' or ''),
-                    "".join([str(s) for s in value._int]),
-                    "0" * (value.adjusted() - (len(value._int)-1)))
+                    "".join([str(s) for s in _int]),
+                    "0" * (value.adjusted() - (len(_int)-1)))
         else:
-            if (len(value._int) - 1) > value.adjusted():
+            if (len(_int) - 1) > value.adjusted():
                 result = "%s%s.%s" % (
                 (value < 0 and '-' or ''),
                 "".join(
-                    [str(s) for s in value._int][0:value.adjusted() + 1]),
+                    [str(s) for s in _int][0:value.adjusted() + 1]),
                 "".join(
-                    [str(s) for s in value._int][value.adjusted() + 1:]))
+                    [str(s) for s in _int][value.adjusted() + 1:]))
             else:
                 result = "%s%s" % (
                 (value < 0 and '-' or ''),
                 "".join(
-                    [str(s) for s in value._int][0:value.adjusted() + 1]))
+                    [str(s) for s in _int][0:value.adjusted() + 1]))
         return result
     
     
@@ -200,5 +209,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
         self.description_encoding = description_encoding
         self.use_scope_identity = self.dbapi and \
                         hasattr(self.dbapi.Cursor, 'nextset')
+        self._need_decimal_fix = self.dbapi and \
+                                tuple(self.dbapi.version.split(".")) < (2, 1, 8)
         
 dialect = MSDialect_pyodbc
index 87a84e514d2cce834ff59711b92b7a767f5620fc..b7d6631388a0edf09271fa009d0b6f90d1808757 100644 (file)
@@ -121,7 +121,7 @@ from sqlalchemy.engine import base
 from sqlalchemy import types as sqltypes, util, exc, processors
 from datetime import datetime
 import random
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 import re
 
 class _OracleNumeric(sqltypes.Numeric):
@@ -148,10 +148,10 @@ class _OracleNumeric(sqltypes.Numeric):
                 def to_decimal(value):
                     if value is None:
                         return None
-                    elif isinstance(value, Decimal):
+                    elif isinstance(value, decimal.Decimal):
                         return value
                     else:
-                        return Decimal(fstring % value)
+                        return decimal.Decimal(fstring % value)
                 return to_decimal
             else:
                 if self.precision is None and self.scale is None:
@@ -569,15 +569,15 @@ class OracleDialect_cx_oracle(OracleDialect):
             self._detect_decimal = \
                 lambda value: _detect_decimal(value.replace(char, '.'))
             self._to_decimal = \
-                lambda value: Decimal(value.replace(char, '.'))
+                lambda value: decimal.Decimal(value.replace(char, '.'))
         
     def _detect_decimal(self, value):
         if "." in value:
-            return Decimal(value)
+            return decimal.Decimal(value)
         else:
             return int(value)
     
-    _to_decimal = Decimal
+    _to_decimal = decimal.Decimal
     
     def on_connect(self):
         if self.cx_oracle_ver < (5,):
index 7b1d8e6a7414d52c01f01289724408dfac65404c..3afa06eaba33baab1fde73690f7cd75ed706b4e1 100644 (file)
@@ -21,10 +21,9 @@ Passing data from/to the Interval type is not supported as of
 yet.
 
 """
-import decimal
-
 from sqlalchemy.engine import default
 from sqlalchemy import util, exc
+from sqlalchemy.util.compat import decimal
 from sqlalchemy import processors
 from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql.base import PGDialect, \
index 88e6ce67096362664af70b91350826c27d8d0bf9..b3f42c330607188992737a1c774ed1e39534f1e2 100644 (file)
@@ -86,10 +86,10 @@ The following per-statement execution options are respected:
 
 import random
 import re
-import decimal
 import logging
 
 from sqlalchemy import util, exc
+from sqlalchemy.util.compat import decimal
 from sqlalchemy import processors
 from sqlalchemy.engine import base, default
 from sqlalchemy.sql import expression
index b8f7991d5f99b8f80a6351b681fb3adb65446b8e..9abdffb6ebc9a64d1d33ff5e9e38ff550b6b2231 100644 (file)
@@ -8,7 +8,6 @@ URLs are of the form ``postgresql+pypostgresql://user@password@host:port/dbname[
 
 """
 from sqlalchemy.engine import default
-import decimal
 from sqlalchemy import util
 from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
index 1d955a7d9c43668ae477fc7e8723a40ca69e4c93..68b16c051ba198cde6b03040fb55f47172c1f7cc 100644 (file)
@@ -31,8 +31,8 @@ Currently *not* supported are::
 from sqlalchemy.dialects.sybase.base import SybaseDialect,\
                                             SybaseExecutionContext
 from sqlalchemy.connectors.pyodbc import PyODBCConnector
-import decimal
 from sqlalchemy import types as sqltypes, util, processors
+from sqlalchemy.util.compat import decimal
 
 class _SybNumeric_pyodbc(sqltypes.Numeric):
     """Turns Decimals with adjusted() < -6 into floats.
index 13dbc6a831502f3341e459aeccbe757c62c6b738..3e592ea51435f6e1cecf6590292c82e176f8b510 100644 (file)
@@ -22,14 +22,14 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
 
 import inspect
 import datetime as dt
-from decimal import Decimal as _python_Decimal
 import codecs
 
 from sqlalchemy import exc, schema
 from sqlalchemy.sql import expression, operators
 import sys
-schema.types = expression.sqltypes =sys.modules['sqlalchemy.types']
+schema.types = expression.sqltypes = sys.modules['sqlalchemy.types']
 from sqlalchemy.util import pickle
+from sqlalchemy.util.compat import decimal
 from sqlalchemy.sql.visitors import Visitable
 from sqlalchemy import util
 from sqlalchemy import processors
@@ -1047,7 +1047,38 @@ class Numeric(_DateAffinity, TypeEngine):
         overhead, and is still subject to floating point data loss - in 
         which case ``asdecimal=False`` will at least remove the extra
         conversion overhead.
-
+        
+        Note that the "cdecimal" library is a high performing alternative
+        to Python's built-in "decimal" type, which performs very poorly
+        in high volume situations.   SQLAlchemy 0.7 is tested against "cdecimal"
+        as well and supports it fully.  The type is not necessarily supported by 
+        DBAPI implementations however, most of which contain an import
+        for plain "decimal" in their source code, even though
+        some such as psycopg2 provide hooks for alternate adapters.   
+        SQLAlchemy imports "decimal" globally as well.  While the 
+        alternate "Decimal" class can be patched into SQLA's "decimal" module,
+        overall the most straightforward and foolproof way to use 
+        "cdecimal" given current support is to patch it directly 
+        into sys.modules before anything else is imported::
+        
+            import sys
+            import cdecimal
+            sys.modules["decimal"] = cdecimal
+        
+        While the global patch is a little ugly, it's particularly 
+        important to use just one decimal library at a time since 
+        Python Decimal and cdecimal Decimal objects 
+        are not currently compatible *with each other*::
+        
+            >>> import cdecimal
+            >>> import decimal
+            >>> decimal.Decimal("10") == cdecimal.Decimal("10")
+            False
+
+        SQLAlchemy will provide more automatic support of 
+        cdecimal if and when it becomes a standard part of Python
+        installations and is supported by all DBAPIs.
+        
         """
         self.precision = precision
         self.scale = scale
@@ -1085,10 +1116,10 @@ class Numeric(_DateAffinity, TypeEngine):
                 # we're a "numeric", DBAPI returns floats, convert.
                 if self.scale is not None:
                     return processors.to_decimal_processor_factory(
-                                _python_Decimal, self.scale)
+                                decimal.Decimal, self.scale)
                 else:
                     return processors.to_decimal_processor_factory(
-                                _python_Decimal)
+                                decimal.Decimal)
         else:
             if dialect.supports_native_decimal:
                 return processors.to_float
@@ -1153,7 +1184,7 @@ class Float(Numeric):
 
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            return processors.to_decimal_processor_factory(_python_Decimal)
+            return processors.to_decimal_processor_factory(decimal.Decimal)
         else:
             return None
 
@@ -1928,9 +1959,6 @@ NULLTYPE = NullType()
 BOOLEANTYPE = Boolean()
 STRINGTYPE = String()
 
-# using VARCHAR/NCHAR so that we dont get the genericized "String"
-# type which usually resolves to TEXT/CLOB
-
 _type_map = {
     str: String(),
     # Py3K
@@ -1941,7 +1969,7 @@ _type_map = {
     int : Integer(),
     float : Numeric(),
     bool: BOOLEANTYPE,
-    _python_Decimal : Numeric(),
+    decimal.Decimal : Numeric(),
     dt.date : Date(),
     dt.datetime : DateTime(),
     dt.time : Time(),
index 961aa1f8a0f50551f0cb10a264be1f2f40236595..59dd9eaf08c16fa62c825a5f2e7e194b936feaec 100644 (file)
@@ -171,7 +171,6 @@ if win32 or jython:
 else:
     time_func = time.time 
 
-
 if sys.version_info >= (2, 5):
     def decode_slice(slc):
         """decode a slice object as sent to __getitem__.
@@ -188,3 +187,7 @@ if sys.version_info >= (2, 5):
 else:
     def decode_slice(slc):
         return (slc.start, slc.stop, slc.step)
+
+
+import decimal
+
index 53d7ff2e4a15679ba961f4c04e5983f326a3f4a4..26b6c7df4b34fd71c856507e741ba3d71bd59e86 100644 (file)
@@ -14,7 +14,7 @@ from sqlalchemy.sql import column
 from sqlalchemy.processors import to_decimal_processor_factory, \
     to_unicode_processor_factory
 from test.lib.util import gc_collect
-from decimal import Decimal as _python_Decimal
+from sqlalchemy.util.compat import decimal
 import gc
 import weakref
 from test.orm import _base
@@ -580,7 +580,7 @@ class MemUsageTest(EnsureZeroed):
     def test_DecimalResultProcessor_process(self):
         @profile_memory
         def go():
-            to_decimal_processor_factory(_python_Decimal, 10)(1.2)
+            to_decimal_processor_factory(decimal.Decimal, 10)(1.2)
         go()
 
     @testing.requires.cextensions
index ef37e4f20035f95c0594648156adf6e3de9986e4..fd43d0ca758c118cf6c4b3bf990ecc6260b9bd91 100644 (file)
@@ -48,6 +48,7 @@ def _list_dbs(*args):
         print "%20s\t%s" % (macro, file_config.get('db', macro))
     sys.exit(0)
 
+
 def _server_side_cursors(options, opt_str, value, parser):
     db_opts['server_side_cursors'] = True
 
@@ -55,23 +56,15 @@ def _engine_strategy(options, opt_str, value, parser):
     if value:
         db_opts['strategy'] = value
 
-class _ordered_map(object):
-    def __init__(self):
-        self._keys = list()
-        self._data = dict()
-
-    def __setitem__(self, key, value):
-        if key not in self._keys:
-            self._keys.append(key)
-        self._data[key] = value
-
-    def __iter__(self):
-        for key in self._keys:
-            yield self._data[key]
+pre_configure = []
+post_configure = []
 
-# at one point in refactoring, modules were injecting into the config
-# process.  this could probably just become a list now.
-post_configure = _ordered_map()
+def _monkeypatch_cdecimal(options, file_config):
+    if options.cdecimal:
+        import sys
+        import cdecimal
+        sys.modules['decimal'] = cdecimal
+pre_configure.append(_monkeypatch_cdecimal)
 
 def _engine_uri(options, file_config):
     global db_label, db_url
@@ -88,7 +81,7 @@ def _engine_uri(options, file_config):
             raise RuntimeError(
                 "Unknown engine.  Specify --dbs for known engines.")
         db_url = file_config.get('db', db_label)
-post_configure['engine_uri'] = _engine_uri
+post_configure.append(_engine_uri)
 
 def _require(options, file_config):
     if not(options.require or
@@ -114,20 +107,20 @@ def _require(options, file_config):
             if seen:
                 continue
             pkg_resources.require(requirement)
-post_configure['require'] = _require
+post_configure.append(_require)
 
 def _engine_pool(options, file_config):
     if options.mockpool:
         from sqlalchemy import pool
         db_opts['poolclass'] = pool.AssertionPool
-post_configure['engine_pool'] = _engine_pool
+post_configure.append(_engine_pool)
 
 def _create_testing_engine(options, file_config):
     from test.lib import engines, testing
     global db
     db = engines.testing_engine(db_url, db_opts)
     testing.db = db
-post_configure['create_engine'] = _create_testing_engine
+post_configure.append(_create_testing_engine)
 
 def _prep_testing_database(options, file_config):
     from test.lib import engines
@@ -149,7 +142,7 @@ def _prep_testing_database(options, file_config):
             md.drop_all()
         e.dispose()
 
-post_configure['prep_db'] = _prep_testing_database
+post_configure.append(_prep_testing_database)
 
 def _set_table_options(options, file_config):
     from test.lib import schema
@@ -161,7 +154,7 @@ def _set_table_options(options, file_config):
 
     if options.mysql_engine:
         table_options['mysql_engine'] = options.mysql_engine
-post_configure['table_options'] = _set_table_options
+post_configure.append(_set_table_options)
 
 def _reverse_topological(options, file_config):
     if options.reversetop:
@@ -169,5 +162,5 @@ def _reverse_topological(options, file_config):
         from sqlalchemy import topological
         from test.lib.util import RandomSet
         topological.set = unitofwork.set = session.set = mapper.set = dependency.set = RandomSet
-post_configure['topological'] = _reverse_topological
+post_configure.append(_reverse_topological)
 
index 838ad0042a353584ae23bc6a1e024e2e9dc3bf52..5ffa8b89fd07a4593eafc2065a3222f9458ba0df 100644 (file)
@@ -15,7 +15,9 @@ from test.bootstrap import config
 from test.bootstrap.config import (
     _create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log,
     _prep_testing_database, _require, _reverse_topological, _server_side_cursors,
-    _set_table_options, base_config, db, db_label, db_url, file_config, post_configure)
+    _monkeypatch_cdecimal,
+    _set_table_options, base_config, db, db_label, db_url, file_config, post_configure, 
+    pre_configure)
 
 log = logging.getLogger('nose.plugins.sqlalchemy')
 
@@ -56,6 +58,8 @@ class NoseSQLAlchemy(Plugin):
         opt("--reversetop", action="store_true", dest="reversetop", default=False,
             help="Use a random-ordering set implementation in the ORM (helps "
                   "reveal dependency issues)")
+        opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False,
+            help="Monkeypatch the cdecimal library into Python 'decimal' for all tests")
         opt("--unhashable", action="store_true", dest="unhashable", default=False,
             help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
         opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
@@ -79,7 +83,9 @@ class NoseSQLAlchemy(Plugin):
     def configure(self, options, conf):
         Plugin.configure(self, options, conf)
         self.options = options
-        
+        for fn in pre_configure:
+            fn(self.options, file_config)
+            
     def begin(self):
         global testing, requires, util
         from test.lib import testing, requires
index 7d43d594bcf6c255c29ef3145fb70c90d2fe3660..6ed420d5c42cf2aad85d4e1d262df51118b524ea 100644 (file)
@@ -4,7 +4,7 @@ from test.lib.testing import eq_
 import StringIO, sys
 from sqlalchemy import *
 from sqlalchemy import exc, sql
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 from sqlalchemy.databases import maxdb
 from test.lib import *
 
@@ -40,7 +40,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
                 _t.create()
             t = Table('dectest', meta, autoload=True)
 
-            vals = [Decimal('2.2'), Decimal('23'), Decimal('2.4'), 25]
+            vals = [decimal.Decimal('2.2'), decimal.Decimal('23'), decimal.Decimal('2.4'), 25]
             cols = ['d1','d2','n1','i1']
             t.insert().execute(dict(zip(cols,vals)))
             roundtrip = list(t.select().execute())
index 63aa874fd0ce26d754c287c61131b8f1ddabb780..f3643c4df348043fcc305bedbe170b4131a3eb81 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.engine import url
 from test.lib import *
 from test.lib.testing import eq_, emits_warning_on, \
     assert_raises_message
+from sqlalchemy.util.compat import decimal
 
 class CompileTest(TestBase, AssertsCompiledSQL):
     __dialect__ = mssql.dialect()
@@ -1043,7 +1044,6 @@ class TypesTest(TestBase, AssertsExecutionResults, ComparesTables):
     @testing.fails_on_everything_except('mssql+pyodbc',
             'this is some pyodbc-specific feature')
     def test_decimal_notation(self):
-        import decimal
         numeric_table = Table('numeric_table', metadata, Column('id',
                               Integer, Sequence('numeric_id_seq',
                               optional=True), primary_key=True),
index 3a0fbac9a4f9ee7f21d3a22c42e1c01a6d9ffb92..d842c7fc226b5aa7d8bdb914685f6819c2864835 100644 (file)
@@ -10,7 +10,7 @@ from test.lib.engines import testing_engine
 from sqlalchemy.dialects.oracle import cx_oracle, base as oracle
 from sqlalchemy.engine import default
 from sqlalchemy.util import jython
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 import datetime
 import os
 
@@ -774,12 +774,12 @@ class TypesTest(TestBase, AssertsCompiledSQL):
             ):
                 for i, (val, type_) in enumerate((
                     (1, int),
-                    (Decimal("5.2"), Decimal),
+                    (decimal.Decimal("5.2"), decimal.Decimal),
                     (6.5, float),
                     (8.5, float),
                     (9.5, float),
                     (12, int),
-                    (Decimal("14.85"), Decimal),
+                    (decimal.Decimal("14.85"), decimal.Decimal),
                     (15.76, float),
                 )):
                     eq_(row[i], val)
@@ -809,8 +809,8 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         foo.create()
         
         foo.insert().execute(
-            {'idata':5, 'ndata':Decimal("45.6"), 'ndata2':Decimal("45.0"), 
-                    'nidata':Decimal('53'), 'fdata':45.68392},
+            {'idata':5, 'ndata':decimal.Decimal("45.6"), 'ndata2':decimal.Decimal("45.0"), 
+                    'nidata':decimal.Decimal('53'), 'fdata':45.68392},
         )
 
         stmt = """
@@ -825,10 +825,10 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         
         
         row = testing.db.execute(stmt).fetchall()[0]
-        eq_([type(x) for x in row], [int, Decimal, Decimal, int, float])
+        eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, int, float])
         eq_(
             row, 
-            (5, Decimal('45.6'), Decimal('45'), 53, 45.683920000000001)
+            (5, decimal.Decimal('45.6'), decimal.Decimal('45'), 53, 45.683920000000001)
         )
 
         # with a nested subquery, 
@@ -852,10 +852,10 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         FROM dual
         """
         row = testing.db.execute(stmt).fetchall()[0]
-        eq_([type(x) for x in row], [int, Decimal, int, int, Decimal])
+        eq_([type(x) for x in row], [int, decimal.Decimal, int, int, decimal.Decimal])
         eq_(
             row, 
-            (5, Decimal('45.6'), 45, 53, Decimal('45.68392'))
+            (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392'))
         )
         
         row = testing.db.execute(text(stmt, 
@@ -866,9 +866,9 @@ class TypesTest(TestBase, AssertsCompiledSQL):
                                         'nidata':Numeric(5, 0),
                                         'fdata':Float()
                                 })).fetchall()[0]
-        eq_([type(x) for x in row], [int, Decimal, Decimal, Decimal, float])
+        eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float])
         eq_(row, 
-            (5, Decimal('45.6'), Decimal('45'), Decimal('53'), 45.683920000000001)
+            (5, decimal.Decimal('45.6'), decimal.Decimal('45'), decimal.Decimal('53'), 45.683920000000001)
         )
         
         stmt = """
@@ -895,8 +895,8 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         WHERE ROWNUM >= 0) anon_1
         """
         row =testing.db.execute(stmt).fetchall()[0]
-        eq_([type(x) for x in row], [int, Decimal, int, int, Decimal])
-        eq_(row, (5, Decimal('45.6'), 45, 53, Decimal('45.68392')))
+        eq_([type(x) for x in row], [int, decimal.Decimal, int, int, decimal.Decimal])
+        eq_(row, (5, decimal.Decimal('45.6'), 45, 53, decimal.Decimal('45.68392')))
 
         row = testing.db.execute(text(stmt, 
                                 typemap={
@@ -906,9 +906,9 @@ class TypesTest(TestBase, AssertsCompiledSQL):
                                         'anon_1_nidata':Numeric(5, 0), 
                                         'anon_1_fdata':Float()
                                 })).fetchall()[0]
-        eq_([type(x) for x in row], [int, Decimal, Decimal, Decimal, float])
+        eq_([type(x) for x in row], [int, decimal.Decimal, decimal.Decimal, decimal.Decimal, float])
         eq_(row, 
-            (5, Decimal('45.6'), Decimal('45'), Decimal('53'), 45.683920000000001)
+            (5, decimal.Decimal('45.6'), decimal.Decimal('45'), decimal.Decimal('53'), 45.683920000000001)
         )
 
         row = testing.db.execute(text(stmt, 
@@ -919,9 +919,9 @@ class TypesTest(TestBase, AssertsCompiledSQL):
                                         'anon_1_nidata':Numeric(5, 0, asdecimal=False), 
                                         'anon_1_fdata':Float(asdecimal=True)
                                 })).fetchall()[0]
-        eq_([type(x) for x in row], [int, float, float, float, Decimal])
+        eq_([type(x) for x in row], [int, float, float, float, decimal.Decimal])
         eq_(row, 
-            (5, 45.6, 45, 53, Decimal('45.68392'))
+            (5, 45.6, 45, 53, decimal.Decimal('45.68392'))
         )
 
         
@@ -1064,11 +1064,11 @@ class EuroNumericTest(TestBase):
     @testing.provide_metadata
     def test_output_type_handler(self):
         for stmt, exp, kw in [
-            ("SELECT 0.1 FROM DUAL", Decimal("0.1"), {}),
+            ("SELECT 0.1 FROM DUAL", decimal.Decimal("0.1"), {}),
             ("SELECT 15 FROM DUAL", 15, {}),
-            ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", Decimal("15"), {}),
-            ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", Decimal("0.1"), {}),
-            ("SELECT :num FROM DUAL", Decimal("2.5"), {'num':Decimal("2.5")})
+            ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", decimal.Decimal("15"), {}),
+            ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", decimal.Decimal("0.1"), {}),
+            ("SELECT :num FROM DUAL", decimal.Decimal("2.5"), {'num':decimal.Decimal("2.5")})
         ]:
             test_exp = self.engine.scalar(stmt, **kw)
             eq_(
index bff1fba68adb1c5d581413d68588877e5de6cfa6..cfccb9bb1c2cca06c16bc489537d0cd99010e437 100644 (file)
@@ -2,12 +2,12 @@
 from test.lib.testing import eq_, assert_raises, assert_raises_message
 from test.lib import  engines
 import datetime
-import decimal
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.engine.strategies import MockEngineStrategy
+from sqlalchemy.util.compat import decimal
 from test.lib import *
 from test.lib.util import round_decimal
 from sqlalchemy.sql import table, column
@@ -480,7 +480,7 @@ class NumericInterpretationTest(TestBase):
     
     def test_numeric_codes(self):
         from sqlalchemy.dialects.postgresql import pg8000, psycopg2, base
-        from decimal import Decimal
+        from sqlalchemy.util.compat import decimal
         
         for dialect in (pg8000.dialect(), psycopg2.dialect()):
             
@@ -491,7 +491,7 @@ class NumericInterpretationTest(TestBase):
                 val = 23.7
                 if proc is not None:
                     val = proc(val)
-                assert val in (23.7, Decimal("23.7"))
+                assert val in (23.7, decimal.Decimal("23.7"))
     
 class InsertTest(TestBase, AssertsExecutionResults):
 
index b8cc05a818b4daaf4a810ffa79471df5a472b65a..4c9892852f7a22432122115550319fcf3ee517ca 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy.util import jython, defaultdict, decorator
+from sqlalchemy.util.compat import decimal
 
 import gc
 import time
@@ -44,8 +45,6 @@ def round_decimal(value, prec):
     if isinstance(value, float):
         return round(value, prec)
     
-    import decimal
-
     # can also use shift() here but that is 2.6 only
     return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
                         pow(10, prec)
index ad074ee533fee97e72b2ad100fc15a366ffc5879..a19be95795ea0640b8279bba96d25ef4507417ec 100644 (file)
@@ -1,6 +1,6 @@
 # -*- encoding: utf8 -*-
 from datetime import *
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 #from fastdec import mpd as Decimal
 from cPickle import dumps, loads
 
index 396eaaf9ba061f4958314832628ad04f7e6ea519..0fb2ca5f7f8eabda2af45b7f61303457daa0ef76 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy import types as sqltypes
 from test.lib import *
 from sqlalchemy.sql.functions import GenericFunction
 from test.lib.testing import eq_
-from decimal import Decimal as _python_Decimal
+from sqlalchemy.util.compat import decimal
 from test.lib import testing
 from sqlalchemy.databases import *
 
@@ -107,7 +107,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                             ((datetime.date(2007, 10, 5), 
                                 datetime.date(2005, 10, 15)), sqltypes.Date),
                             ((3, 5), sqltypes.Integer),
-                            ((_python_Decimal(3), _python_Decimal(5)), sqltypes.Numeric),
+                            ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric),
                             (("foo", "bar"), sqltypes.String),
                             ((datetime.datetime(2007, 10, 5, 8, 3, 34), 
                                 datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime)
index 73d99ac6a401ff0fd426dc4b78cb353902f3198a..3d9be543c74cc00197013e76b3ba181fb5380f6a 100644 (file)
@@ -11,7 +11,7 @@ from sqlalchemy.databases import *
 from test.lib.schema import Table, Column
 from test.lib import *
 from test.lib.util import picklers
-from decimal import Decimal
+from sqlalchemy.util.compat import decimal
 from test.lib.util import round_decimal
 
 
@@ -1262,8 +1262,8 @@ class NumericTest(TestBase):
     def test_numeric_as_decimal(self):
         self._do_test(
             Numeric(precision=8, scale=4),
-            [15.7563, Decimal("15.7563"), None],
-            [Decimal("15.7563"), None], 
+            [15.7563, decimal.Decimal("15.7563"), None],
+            [decimal.Decimal("15.7563"), None], 
         )
 
     def test_numeric_as_float(self):
@@ -1274,7 +1274,7 @@ class NumericTest(TestBase):
 
         self._do_test(
             Numeric(precision=8, scale=4, asdecimal=False),
-            [15.7563, Decimal("15.7563"), None],
+            [15.7563, decimal.Decimal("15.7563"), None],
             [15.7563, None],
             filter_ = filter_
         )
@@ -1282,15 +1282,15 @@ class NumericTest(TestBase):
     def test_float_as_decimal(self):
         self._do_test(
             Float(precision=8, asdecimal=True),
-            [15.7563, Decimal("15.7563"), None],
-            [Decimal("15.7563"), None], 
+            [15.7563, decimal.Decimal("15.7563"), None],
+            [decimal.Decimal("15.7563"), None], 
             filter_ = lambda n:n is not None and round(n, 5) or None
         )
 
     def test_float_as_float(self):
         self._do_test(
             Float(precision=8),
-            [15.7563, Decimal("15.7563")],
+            [15.7563, decimal.Decimal("15.7563")],
             [15.7563],
             filter_ = lambda n:n is not None and round(n, 5) or None
         )
@@ -1390,18 +1390,18 @@ class NumericRawSQLTest(TestBase):
     @testing.fails_on('sqlite', "Doesn't provide Decimal results natively")
     @testing.provide_metadata
     def test_decimal_fp(self):
-        t = self._fixture(metadata, Numeric(10, 5), Decimal("45.5"))
+        t = self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5"))
         val = testing.db.execute("select val from t").scalar()
-        assert isinstance(val, Decimal)
-        eq_(val, Decimal("45.5"))
+        assert isinstance(val, decimal.Decimal)
+        eq_(val, decimal.Decimal("45.5"))
 
     @testing.fails_on('sqlite', "Doesn't provide Decimal results natively")
     @testing.provide_metadata
     def test_decimal_int(self):
-        t = self._fixture(metadata, Numeric(10, 5), Decimal("45"))
+        t = self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45"))
         val = testing.db.execute("select val from t").scalar()
-        assert isinstance(val, Decimal)
-        eq_(val, Decimal("45"))
+        assert isinstance(val, decimal.Decimal)
+        eq_(val, decimal.Decimal("45"))
 
     @testing.provide_metadata
     def test_ints(self):