]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
clean up NumericTest to use a consistent one column at a time system
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2010 19:34:09 +0000 (15:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2010 19:34:09 +0000 (15:34 -0400)
test/sql/test_types.py

index 9e82fc1ef428786cdb84f200da0f4ad339285735..123eab48ca60756678016d6086a85a437949d47f 100644 (file)
@@ -1075,7 +1075,7 @@ class DateTest(TestBase, AssertsExecutionResults):
         finally:
             t.drop(checkfirst=True)
 
-class StringTest(TestBase, AssertsExecutionResults):
+class StringTest(TestBase):
 
     @testing.requires.unbounded_varchar
     def test_nolength_string(self):
@@ -1085,74 +1085,74 @@ class StringTest(TestBase, AssertsExecutionResults):
         foo.create()
         foo.drop()
 
-class NumericTest(TestBase, AssertsExecutionResults):
-    @classmethod
-    def setup_class(cls):
-        global numeric_table, metadata
+class NumericTest(TestBase):
+    def setup(self):
+        global metadata
         metadata = MetaData(testing.db)
-        numeric_table = Table('numeric_table', metadata,
-            Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True),
-            Column('numericcol', Numeric(precision=10, scale=2, asdecimal=False)),
-            Column('floatcol', Float(precision=10, )),
-            Column('ncasdec', Numeric(precision=10, scale=2)),
-            Column('fcasdec', Float(precision=10, asdecimal=True))
-        )
-        metadata.create_all()
-
-    @classmethod
-    def teardown_class(cls):
+        
+    def teardown(self):
         metadata.drop_all()
+        
+    def _do_test(self, type_, input_, output, filter_ = None):
+        t = Table('t', metadata, Column('x', type_))
+        t.create()
+        t.insert().execute([{'x':x} for x in input_])
+
+        result = set([row[0] for row in t.select().execute()])
+        output = set(output)
+        if filter_:
+            result = set(filter_(x) for x in result)
+            output = set(filter_(x) for x in output)
+        eq_(result, output)
+        
+    def test_numeric_as_decimal(self):
+        self._do_test(
+            Numeric(precision=8, scale=4),
+            [15.7563, Decimal("15.7563")],
+            [Decimal("15.7563")], 
+        )
 
-    @engines.close_first
-    def teardown(self):
-        numeric_table.delete().execute()
+    def test_numeric_as_float(self):
+        self._do_test(
+            Numeric(precision=8, scale=4, asdecimal=False),
+            [15.7563, Decimal("15.7563")],
+            [15.7563]
+        )
 
-    def test_decimal(self):
-        from decimal import Decimal
-        numeric_table.insert().execute(
-            numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75)
-            
-        numeric_table.insert().execute(
-            numericcol=Decimal("3.5"), floatcol=Decimal("5.6"),
-            ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75"))
-
-        l = numeric_table.select().order_by(numeric_table.c.id).execute().fetchall()
-        rounded = [
-            (l[0][0], l[0][1], round(l[0][2], 5), l[0][3], l[0][4]),
-            (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]),
-        ]
-        testing.eq_(rounded, [
-            (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")),
-            (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")),
-        ])
+    def test_float_as_decimal(self):
+        self._do_test(
+            Float(precision=8, asdecimal=True),
+            [15.7563, Decimal("15.7563")],
+            [Decimal("15.7563")], 
+            filter_ = lambda n:round(n, 5)
+        )
 
+    def test_float_as_float(self):
+        self._do_test(
+            Float(precision=8),
+            [15.7563, Decimal("15.7563")],
+            [15.7563],
+            filter_ = lambda n:round(n, 5)
+        )
+        
     def test_precision_decimal(self):
+        numbers = set([
+            decimal.Decimal("54.234246451650"),
+            decimal.Decimal("876734.594069654000"),
+            decimal.Decimal("0.004354"), 
+            decimal.Decimal("900.0"), 
+        ])
+        if testing.against('sqlite', 'sybase+pysybase', 'oracle+cx_oracle'):
+            filter_ = lambda n:round_decimal(n, 11)
+        else:
+            filter_ = None
             
-        t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
-        t.create(testing.db)
-        try:
-            numbers = set(
-            [
-                decimal.Decimal("54.234246451650"),
-                decimal.Decimal("876734.594069654000"),
-                decimal.Decimal("0.004354"), 
-                decimal.Decimal("900.0"), 
-            ])
-
-            testing.db.execute(t.insert(), [{'x':x} for x in numbers])
-
-            ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
-            
-            if testing.against('sqlite', 'sybase+pysybase', 'oracle+cx_oracle'):
-                numbers = set(round_decimal(n, 11) for n in numbers)
-                ret = set(round_decimal(n, 11) for n in ret)
-            else:
-                numbers = set(n for n in numbers)
-                ret = set(n for n in ret)
-            
-            eq_(numbers, ret)
-        finally:
-            t.drop(testing.db)
+        self._do_test(
+            Numeric(precision=18, scale=12),
+            numbers,
+            numbers,
+            filter_=filter_
+        )
 
     def test_enotation_decimal(self):
         """test exceedingly small decimals.
@@ -1161,42 +1161,22 @@ class NumericTest(TestBase, AssertsExecutionResults):
         is greater than 6.
         
         """
-
-        t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
-        t.create(testing.db)
-        try:
-            numbers = set([
-                decimal.Decimal('1E-2'),
-                decimal.Decimal('1E-3'),
-                decimal.Decimal('1E-4'),
-                decimal.Decimal('1E-5'),
-                decimal.Decimal('1E-6'),
-                decimal.Decimal('1E-7'),
-                decimal.Decimal('1E-8'),
-            ])
-
-            testing.db.execute(t.insert(), [{'x':x} for x in numbers])
-
-            ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
-            
-            numbers = set(n for n in numbers)
-            ret = set(n for n in ret)
-            
-            eq_(numbers, ret)
-        finally:
-            t.drop(testing.db)
         
-
-    def test_decimal_fallback(self):
-        from decimal import Decimal
-
-        numeric_table.insert().execute(ncasdec=12.4, fcasdec=15.75)
-        numeric_table.insert().execute(ncasdec=Decimal("12.4"),
-                                       fcasdec=Decimal("15.75"))
-
-        for row in numeric_table.select().execute().fetchall():
-            assert isinstance(row['ncasdec'], decimal.Decimal)
-            assert isinstance(row['fcasdec'], decimal.Decimal)
+        numbers = set([
+            decimal.Decimal('1E-2'),
+            decimal.Decimal('1E-3'),
+            decimal.Decimal('1E-4'),
+            decimal.Decimal('1E-5'),
+            decimal.Decimal('1E-6'),
+            decimal.Decimal('1E-7'),
+            decimal.Decimal('1E-8'),
+        ])
+        self._do_test(
+            Numeric(precision=18, scale=12),
+            numbers,
+            numbers
+        )
+        
 
             
 class IntervalTest(TestBase, AssertsExecutionResults):