From: Mike Bayer Date: Tue, 14 Dec 2010 01:23:24 +0000 (-0500) Subject: some tests, should be OK X-Git-Tag: rel_0_7b1~162 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bfaa97dbce7e4f4c8d7eddc49c164945701bbe00;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git some tests, should be OK --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 7dd7400ea6..4c0a008904 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -175,8 +175,9 @@ class REAL(sqltypes.Float): __visit_name__ = 'REAL' - def __init__(self): - super(REAL, self).__init__(precision=24) + def __init__(self, **kw): + kw.setdefault('precision', 24) + super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): __visit_name__ = 'TINYINT' @@ -258,7 +259,8 @@ class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): class DATETIME2(_DateTimeBase, sqltypes.DateTime): __visit_name__ = 'DATETIME2' - def __init__(self, precision=None, **kwargs): + def __init__(self, precision=None, **kw): + super(DATETIME2, self).__init__(**kw) self.precision = precision diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index deeebf0f90..fd99a16b55 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -771,7 +771,7 @@ class CHAR(_StringType, sqltypes.CHAR): __visit_name__ = 'CHAR' - def __init__(self, length, **kwargs): + def __init__(self, length=None, **kwargs): """Construct a CHAR. :param length: Maximum data length, in characters. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 85ac3192ff..447938461e 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -131,7 +131,7 @@ class TypeEngine(AbstractType): else: return self.__class__ - def dialect_impl(self, dialect, **kwargs): + def dialect_impl(self, dialect): """Return a dialect-specific implementation for this type.""" try: @@ -149,22 +149,6 @@ class TypeEngine(AbstractType): d['bind'] = bp = d['impl'].bind_processor(dialect) return bp - def _dialect_info(self, dialect): - """Return a dialect-specific registry containing bind/result processors.""" - - if self in dialect._type_memos: - return dialect._type_memos[self] - else: - impl = self._gen_dialect_impl(dialect) - # the impl we put in here - # must not have any references to self. - if impl is self: - impl = self.adapt(type(self)) - dialect._type_memos[self] = d = { - 'impl':impl, - } - return d - def _cached_result_processor(self, dialect, coltype): """Return a dialect-specific result processor for this type.""" @@ -172,11 +156,28 @@ class TypeEngine(AbstractType): return dialect._type_memos[self][coltype] except KeyError: d = self._dialect_info(dialect) - # another key assumption. DBAPI type codes are - # constants. + # key assumption: DBAPI type codes are + # constants. Else this dictionary would + # grow unbounded. d[coltype] = rp = d['impl'].result_processor(dialect, coltype) return rp + def _dialect_info(self, dialect): + """Return a dialect-specific registry which + caches a dialect-specific implementation, bind processing + function, and one or more result processing functions.""" + + if self in dialect._type_memos: + return dialect._type_memos[self] + else: + impl = self._gen_dialect_impl(dialect) + if impl is self: + impl = self.adapt(type(self)) + # this can't be self, else we create a cycle + assert impl is not self + dialect._type_memos[self] = d = {'impl':impl} + return d + def _gen_dialect_impl(self, dialect): return dialect.type_descriptor(self) @@ -792,7 +793,7 @@ class String(Concatenable, TypeEngine): length=self.length, convert_unicode=self.convert_unicode, unicode_error=self.unicode_error, - _warn_on_bytestring=True, + _warn_on_bytestring=self._warn_on_bytestring, **kw ) @@ -1171,7 +1172,9 @@ class Float(Numeric): """ __visit_name__ = 'float' - + + scale = None + def __init__(self, precision=None, asdecimal=False, **kwargs): """ Construct a Float. @@ -1787,7 +1790,7 @@ class Interval(_DateAffinity, TypeDecorator): self.day_precision = day_precision def adapt(self, cls, **kw): - if self.native: + if self.native and hasattr(cls, '_adapt_from_generic_interval'): return cls._adapt_from_generic_interval(self, **kw) else: return cls(**kw) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 3d9be543c7..f9307daafa 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -3,19 +3,44 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message import decimal import datetime, os, re from sqlalchemy import * -from sqlalchemy import exc, types, util, schema +from sqlalchemy import exc, types, util, schema, dialects +for name in dialects.__all__: + __import__("sqlalchemy.dialects.%s" % name) from sqlalchemy.sql import operators, column, table from test.lib.testing import eq_ import sqlalchemy.engine.url as url -from sqlalchemy.databases import * from test.lib.schema import Table, Column from test.lib import * from test.lib.util import picklers from sqlalchemy.util.compat import decimal from test.lib.util import round_decimal - class AdaptTest(TestBase): + def _all_dialect_modules(self): + return [ + getattr(dialects, d) + for d in dialects.__all__ + if not d.startswith('_') + ] + + def _all_dialects(self): + return [d.base.dialect() for d in + self._all_dialect_modules()] + + def _all_types(self): + def types_for_mod(mod): + for key in dir(mod): + typ = getattr(mod, key) + if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine): + continue + yield typ + + for typ in types_for_mod(types): + yield typ + for dialect in self._all_dialect_modules(): + for typ in types_for_mod(dialect): + yield typ + def test_uppercase_rendering(self): """Test that uppercase types from types.py always render as their type. @@ -27,12 +52,7 @@ class AdaptTest(TestBase): """ - for dialect in [ - oracle.dialect(), - mysql.dialect(), - postgresql.dialect(), - sqlite.dialect(), - mssql.dialect()]: + for dialect in self._all_dialects(): for type_, expected in ( (FLOAT, "FLOAT"), (NUMERIC, "NUMERIC"), @@ -49,7 +69,7 @@ class AdaptTest(TestBase): "NVARCHAR2(10)")), (CHAR, "CHAR"), (NCHAR, ("NCHAR", "NATIONAL CHAR")), - (BLOB, "BLOB"), + (BLOB, ("BLOB", "BLOB SUB_TYPE 0")), (BOOLEAN, ("BOOLEAN", "BOOL")) ): if isinstance(expected, str): @@ -65,7 +85,40 @@ class AdaptTest(TestBase): assert str(types.to_instance(type_)) in expected, \ "default str() of type %r not expected, %r" % \ (type_, expected) - + + @testing.uses_deprecated() + def test_adapt_method(self): + """ensure all types have a working adapt() method, + which creates a distinct copy. + + The distinct copy ensures that when we cache + the adapted() form of a type against the original + in a weak key dictionary, a cycle is not formed. + + This test doesn't test type-specific arguments of + adapt() beyond their defaults. + + """ + + for typ in self._all_types(): + if typ in (types.TypeDecorator, types.TypeEngine): + continue + elif typ is dialects.postgresql.ARRAY: + t1 = typ(String) + else: + t1 = typ() + for cls in [typ] + typ.__subclasses__(): + if not issubclass(typ, types.Enum) and \ + issubclass(cls, types.Enum): + continue + t2 = t1.adapt(cls) + assert t1 is not t2 + for k in t1.__dict__: + if k == 'impl': + continue + eq_(getattr(t2, k), t1.__dict__[k]) + + class TypeAffinityTest(TestBase): def test_type_affinity(self): for type_, affin in [ @@ -155,7 +208,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): (Float(2), "FLOAT(2)", {'precision':4}), (Numeric(19, 2), "NUMERIC(19, 2)", {}), ]: - for dialect_ in (postgresql, mssql, mysql): + for dialect_ in (dialects.postgresql, dialects.mssql, dialects.mysql): dialect_ = dialect_.dialect() raw_impl = types.to_instance(impl_, **kw) @@ -188,8 +241,8 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): else: return super(MyType, self).load_dialect_impl(dialect) - sl = sqlite.dialect() - pg = postgresql.dialect() + sl = dialects.sqlite.dialect() + pg = dialects.postgresql.dialect() t = MyType() self.assert_compile(t, "VARCHAR(50)", dialect=sl) self.assert_compile(t, "FLOAT", dialect=pg) @@ -1082,12 +1135,12 @@ class CompileTest(TestBase, AssertsCompiledSQL): for type_, expected in ( (String(), "VARCHAR"), (Integer(), "INTEGER"), - (postgresql.INET(), "INET"), - (postgresql.FLOAT(), "FLOAT"), - (mysql.REAL(precision=8, scale=2), "REAL(8, 2)"), - (postgresql.REAL(), "REAL"), + (dialects.postgresql.INET(), "INET"), + (dialects.postgresql.FLOAT(), "FLOAT"), + (dialects.mysql.REAL(precision=8, scale=2), "REAL(8, 2)"), + (dialects.postgresql.REAL(), "REAL"), (INTEGER(), "INTEGER"), - (mysql.INTEGER(display_width=5), "INTEGER(5)") + (dialects.mysql.INTEGER(display_width=5), "INTEGER(5)") ): self.assert_compile(type_, expected)