]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some tests, should be OK
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Dec 2010 01:23:24 +0000 (20:23 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Dec 2010 01:23:24 +0000 (20:23 -0500)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/types.py
test/sql/test_types.py

index 7dd7400ea60408b6e72ed0cb8e58e2b32c9da936..4c0a0089042622dfa6856a9f620480e227c69f04 100644 (file)
@@ -175,8 +175,9 @@ class REAL(sqltypes.Float):
 
     __visit_name__ = 'REAL'
 
-    def __init__(self):
-        super(REAL, self).__init__(precision=24)
+    def __init__(self, **kw):
+        kw.setdefault('precision', 24)
+        super(REAL, self).__init__(**kw)
 
 class TINYINT(sqltypes.Integer):
     __visit_name__ = 'TINYINT'
@@ -258,7 +259,8 @@ class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
 class DATETIME2(_DateTimeBase, sqltypes.DateTime):
     __visit_name__ = 'DATETIME2'
     
-    def __init__(self, precision=None, **kwargs):
+    def __init__(self, precision=None, **kw):
+        super(DATETIME2, self).__init__(**kw)
         self.precision = precision
 
 
index deeebf0f906c247d54f126da6f6a75c902e970c5..fd99a16b55aea9e80354eea49e949620f9555d78 100644 (file)
@@ -771,7 +771,7 @@ class CHAR(_StringType, sqltypes.CHAR):
 
     __visit_name__ = 'CHAR'
 
-    def __init__(self, length, **kwargs):
+    def __init__(self, length=None, **kwargs):
         """Construct a CHAR.
 
         :param length: Maximum data length, in characters.
index 85ac3192ff65a8bc346f276766728399d6a70dc2..447938461e5d048bc84ee3c6e42479ba0966ed39 100644 (file)
@@ -131,7 +131,7 @@ class TypeEngine(AbstractType):
         else:
             return self.__class__
 
-    def dialect_impl(self, dialect, **kwargs):
+    def dialect_impl(self, dialect):
         """Return a dialect-specific implementation for this type."""
         
         try:
@@ -149,22 +149,6 @@ class TypeEngine(AbstractType):
             d['bind'] = bp = d['impl'].bind_processor(dialect)
             return bp
     
-    def _dialect_info(self, dialect):
-        """Return a dialect-specific registry containing bind/result processors."""
-        
-        if self in dialect._type_memos:
-            return dialect._type_memos[self]
-        else:
-            impl = self._gen_dialect_impl(dialect)
-            # the impl we put in here
-            # must not have any references to self.
-            if impl is self:
-                impl = self.adapt(type(self))
-            dialect._type_memos[self] = d = {
-                'impl':impl,
-            }
-            return d
-        
     def _cached_result_processor(self, dialect, coltype):
         """Return a dialect-specific result processor for this type."""
 
@@ -172,11 +156,28 @@ class TypeEngine(AbstractType):
             return dialect._type_memos[self][coltype]
         except KeyError:
             d = self._dialect_info(dialect)
-            # another key assumption.  DBAPI type codes are
-            # constants.   
+            # key assumption: DBAPI type codes are
+            # constants.  Else this dictionary would
+            # grow unbounded.
             d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
             return rp
 
+    def _dialect_info(self, dialect):
+        """Return a dialect-specific registry which 
+        caches a dialect-specific implementation, bind processing
+        function, and one or more result processing functions."""
+        
+        if self in dialect._type_memos:
+            return dialect._type_memos[self]
+        else:
+            impl = self._gen_dialect_impl(dialect)
+            if impl is self:
+                impl = self.adapt(type(self))
+            # this can't be self, else we create a cycle
+            assert impl is not self
+            dialect._type_memos[self] = d = {'impl':impl}
+            return d
+
     def _gen_dialect_impl(self, dialect):
         return dialect.type_descriptor(self)
         
@@ -792,7 +793,7 @@ class String(Concatenable, TypeEngine):
                     length=self.length,
                     convert_unicode=self.convert_unicode,
                     unicode_error=self.unicode_error,
-                    _warn_on_bytestring=True,
+                    _warn_on_bytestring=self._warn_on_bytestring,
                     **kw
                     )
 
@@ -1171,7 +1172,9 @@ class Float(Numeric):
     """
 
     __visit_name__ = 'float'
-
+    
+    scale = None
+    
     def __init__(self, precision=None, asdecimal=False, **kwargs):
         """
         Construct a Float.
@@ -1787,7 +1790,7 @@ class Interval(_DateAffinity, TypeDecorator):
         self.day_precision = day_precision
 
     def adapt(self, cls, **kw):
-        if self.native:
+        if self.native and hasattr(cls, '_adapt_from_generic_interval'):
             return cls._adapt_from_generic_interval(self, **kw)
         else:
             return cls(**kw)
index 3d9be543c74cc00197013e76b3ba181fb5380f6a..f9307daafa6a9d8a0f4c9ff5deb57286e2295ac2 100644 (file)
@@ -3,19 +3,44 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message
 import decimal
 import datetime, os, re
 from sqlalchemy import *
-from sqlalchemy import exc, types, util, schema
+from sqlalchemy import exc, types, util, schema, dialects
+for name in dialects.__all__:
+    __import__("sqlalchemy.dialects.%s" % name)
 from sqlalchemy.sql import operators, column, table
 from test.lib.testing import eq_
 import sqlalchemy.engine.url as url
-from sqlalchemy.databases import *
 from test.lib.schema import Table, Column
 from test.lib import *
 from test.lib.util import picklers
 from sqlalchemy.util.compat import decimal
 from test.lib.util import round_decimal
 
-
 class AdaptTest(TestBase):
+    def _all_dialect_modules(self):
+        return [
+            getattr(dialects, d)
+            for d in dialects.__all__
+            if not d.startswith('_')
+        ]
+        
+    def _all_dialects(self):
+        return [d.base.dialect() for d in 
+                self._all_dialect_modules()]
+    
+    def _all_types(self):
+        def types_for_mod(mod):
+            for key in dir(mod):
+                typ = getattr(mod, key)
+                if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
+                    continue
+                yield typ
+        
+        for typ in types_for_mod(types):
+            yield typ
+        for dialect in self._all_dialect_modules():
+            for typ in types_for_mod(dialect):
+                yield typ
+        
     def test_uppercase_rendering(self):
         """Test that uppercase types from types.py always render as their
         type.
@@ -27,12 +52,7 @@ class AdaptTest(TestBase):
         
         """
         
-        for dialect in [
-                oracle.dialect(), 
-                mysql.dialect(), 
-                postgresql.dialect(), 
-                sqlite.dialect(), 
-                mssql.dialect()]: 
+        for dialect in self._all_dialects():
             for type_, expected in (
                 (FLOAT, "FLOAT"),
                 (NUMERIC, "NUMERIC"),
@@ -49,7 +69,7 @@ class AdaptTest(TestBase):
                                     "NVARCHAR2(10)")),
                 (CHAR, "CHAR"),
                 (NCHAR, ("NCHAR", "NATIONAL CHAR")),
-                (BLOB, "BLOB"),
+                (BLOB, ("BLOB", "BLOB SUB_TYPE 0")),
                 (BOOLEAN, ("BOOLEAN", "BOOL"))
             ):
                 if isinstance(expected, str):
@@ -65,7 +85,40 @@ class AdaptTest(TestBase):
                 assert str(types.to_instance(type_)) in expected, \
                     "default str() of type %r not expected, %r" % \
                     (type_, expected)
-                
+    
+    @testing.uses_deprecated()
+    def test_adapt_method(self):
+        """ensure all types have a working adapt() method,
+        which creates a distinct copy.   
+        
+        The distinct copy ensures that when we cache
+        the adapted() form of a type against the original
+        in a weak key dictionary, a cycle is not formed.
+        
+        This test doesn't test type-specific arguments of
+        adapt() beyond their defaults.
+        
+        """
+        
+        for typ in self._all_types():
+            if typ in (types.TypeDecorator, types.TypeEngine):
+                continue
+            elif typ is dialects.postgresql.ARRAY:
+                t1 = typ(String)
+            else:
+                t1 = typ()
+            for cls in [typ] + typ.__subclasses__():
+                if not issubclass(typ, types.Enum) and \
+                    issubclass(cls, types.Enum):
+                    continue
+                t2 = t1.adapt(cls)
+                assert t1 is not t2
+                for k in t1.__dict__:
+                    if k == 'impl':
+                        continue
+                    eq_(getattr(t2, k), t1.__dict__[k])
+        
+        
 class TypeAffinityTest(TestBase):
     def test_type_affinity(self):
         for type_, affin in [
@@ -155,7 +208,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             (Float(2), "FLOAT(2)", {'precision':4}),
             (Numeric(19, 2), "NUMERIC(19, 2)", {}),
         ]:
-            for dialect_ in (postgresql, mssql, mysql):
+            for dialect_ in (dialects.postgresql, dialects.mssql, dialects.mysql):
                 dialect_ = dialect_.dialect()
                 
                 raw_impl = types.to_instance(impl_, **kw)
@@ -188,8 +241,8 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
                 else:
                     return super(MyType, self).load_dialect_impl(dialect)
         
-        sl = sqlite.dialect()
-        pg = postgresql.dialect()
+        sl = dialects.sqlite.dialect()
+        pg = dialects.postgresql.dialect()
         t = MyType()
         self.assert_compile(t, "VARCHAR(50)", dialect=sl)
         self.assert_compile(t, "FLOAT", dialect=pg)
@@ -1082,12 +1135,12 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         for type_, expected in (
             (String(), "VARCHAR"),
             (Integer(), "INTEGER"),
-            (postgresql.INET(), "INET"),
-            (postgresql.FLOAT(), "FLOAT"),
-            (mysql.REAL(precision=8, scale=2), "REAL(8, 2)"),
-            (postgresql.REAL(), "REAL"),
+            (dialects.postgresql.INET(), "INET"),
+            (dialects.postgresql.FLOAT(), "FLOAT"),
+            (dialects.mysql.REAL(precision=8, scale=2), "REAL(8, 2)"),
+            (dialects.postgresql.REAL(), "REAL"),
             (INTEGER(), "INTEGER"),
-            (mysql.INTEGER(display_width=5), "INTEGER(5)")
+            (dialects.mysql.INTEGER(display_width=5), "INTEGER(5)")
         ):
             self.assert_compile(type_, expected)