From: Mike Bayer Date: Wed, 14 Jan 2009 16:58:20 +0000 (+0000) Subject: more test fixup, type correction X-Git-Tag: rel_0_6_6~342 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=09478366b76665b36cfb50954da7c241bf5f1657;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more test fixup, type correction --- diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index a824cd87b2..b45ea73d66 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -4,6 +4,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from sqlalchemy.dialects.sqlite import base as sqlite +from sqlalchemy.dialects.postgres import base as postgres +from sqlalchemy.dialects.mysql import base as mysql + __all__ = ( 'access', @@ -11,6 +15,9 @@ __all__ = ( 'informix', 'maxdb', 'mssql', + 'mysql', + 'postgres', + 'sqlite', 'oracle', 'sybase', ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 9c2cf0352a..67c73efb71 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1127,7 +1127,7 @@ class MSSet(MSString): only the collation of character data. """ - self.__ddl_values = values + self._ddl_values = values strip_values = [] for a in values: @@ -1545,7 +1545,6 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, 'REAL') - def visit_FLOAT(self, type_): if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None: return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) @@ -1647,6 +1646,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self.visit_BLOB(type_) + def visit_binary(self, type_): + return self.visit_BLOB(type_) + def visit_BINARY(self, type_): if type_.length: return "BINARY(%d)" % type_.length @@ -1675,9 +1677,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return self._extend_string(type_, "ENUM(%s)" % ",".join(quoted_enums)) def visit_SET(self, type_): - return self._extend_string("SET(%s)" % ",".join(type_._ddl_values)) + return self._extend_string(type_, "SET(%s)" % ",".join(type_._ddl_values)) - def visit_BOOL(self, type): + def visit_BOOLEAN(self, type): return "BOOL" @@ -1819,6 +1821,7 @@ class MySQLDialect(default.DefaultDialect): if rs: rs.close() + @engine_base.connection_memoize(('mysql', 'server_version_info')) def server_version_info(self, connection): """A tuple of the database server version. @@ -1831,9 +1834,9 @@ class MySQLDialect(default.DefaultDialect): cached per-Connection. """ + # TODO: do we need to bypass ConnectionFairy here? other calls + # to this seem to not do that. return self._server_version_info(connection.connection.connection) - server_version_info = engine_base.connection_memoize( - ('mysql', 'server_version_info'))(server_version_info) def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" @@ -1850,7 +1853,7 @@ class MySQLDialect(default.DefaultDialect): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) - self.reflector = reflector = MySQLSchemaReflector(self) + self.reflector = reflector = MySQLSchemaReflector(self, preparer) sql = self._show_create_table(connection, table, charset) if sql.startswith('CREATE ALGORITHM'): @@ -2047,7 +2050,7 @@ class MySQLDialect(default.DefaultDialect): class MySQLSchemaReflector(object): """Parses SHOW CREATE TABLE output.""" - def __init__(self, dialect): + def __init__(self, dialect, preparer=None): """Construct a MySQLSchemaReflector. identifier_preparer @@ -2056,7 +2059,7 @@ class MySQLSchemaReflector(object): """ self.dialect = dialect - self.preparer = dialect.identifier_preparer + self.preparer = preparer or dialect.identifier_preparer self._prep_regexes() def reflect(self, connection, table, show_create, charset, only=None): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index e9672eddf7..15ed21c77d 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -368,6 +368,7 @@ class PGDialect(default.DefaultDialect): supports_pk_autoincrement = False supports_default_values = True supports_empty_insert = False + default_paramstyle = 'pyformat' statement_compiler = PGCompiler ddl_compiler = PGDDLCompiler diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index d716a58ec4..ba08ccbb90 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -219,6 +219,7 @@ class SQLiteDialect(default.DefaultDialect): supports_default_values = True supports_empty_insert = False supports_cast = True + default_paramstyle = 'qmark' statement_compiler = SQLiteCompiler ddl_compiler = SQLiteDDLCompiler type_compiler = SQLiteTypeCompiler diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 6def864e89..4dd5ea2861 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -50,7 +50,9 @@ url.py within a URL. """ -import sqlalchemy.databases +# not sure what this was used for +#import sqlalchemy.databases + from sqlalchemy.engine.base import ( BufferedColumnResultProxy, BufferedColumnRow, diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 0ca9240110..ad16de89a3 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -178,7 +178,7 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table) for col in numeric_table.c: index = int(col.name[1:]) @@ -262,7 +262,7 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table) for col in charset_table.c: index = int(col.name[1:]) @@ -741,7 +741,8 @@ class TypesTest(TestBase, AssertsExecutionResults): for table in tables: for i, reflected in enumerate(table.c): - assert isinstance(reflected.type, type(expected[i])) + assert isinstance(reflected.type, type(expected[i])), \ + "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i])) finally: db.execute('DROP VIEW mysql_types_v') finally: @@ -986,8 +987,7 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): def setUp(self): self.dialect = mysql.dialect() - self.reflector = mysql.MySQLSchemaReflector( - self.dialect.identifier_preparer) + self.reflector = mysql.MySQLSchemaReflector(self.dialect) def test_key_reflection(self): regex = self.reflector._re_key @@ -1147,8 +1147,7 @@ class MatchTest(TestBase, AssertsCompiledSQL): def colspec(c): - return testing.db.dialect.schemagenerator(testing.db.dialect, - testing.db, None, None).get_column_specification(c) + return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c) if __name__ == "__main__": testenv.main() diff --git a/test/sql/select.py b/test/sql/select.py index 671ccab1a0..77112e649a 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -5,9 +5,7 @@ from sqlalchemy import exc, sql, util from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default -from sqlalchemy.databases import mysql, oracle, firebird, mssql -from sqlalchemy.dialects.sqlite import pysqlite as sqlite -from sqlalchemy.dialects.postgres import psycopg2 as postgres +from sqlalchemy.databases import * from testlib import * table1 = table('mytable', @@ -1311,7 +1309,6 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')]) assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg'] - from sqlalchemy.databases.sqlite import SLNumeric meta = MetaData() t1 = Table('mytable', meta, Column('col1', Integer)) @@ -1319,7 +1316,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") (table1.c.name, 'name', 'mytable.name', None), (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'), (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'), - (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), + (cast(table1.c.name, sqlite.SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), (t1.c.col1, 'col1', 'mytable.col1', None), (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '') ): diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index da649d0970..9ce7b7662a 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -6,9 +6,7 @@ from sqlalchemy import exc, types, util, schema from sqlalchemy.sql import operators from testlib.testing import eq_ import sqlalchemy.engine.url as url -from sqlalchemy.databases import mssql, oracle, mysql, firebird -from sqlalchemy.dialects.sqlite import pysqlite as sqlite -from sqlalchemy.dialects.postgres import psycopg2 as postgres +from sqlalchemy.databases import * from testlib import * diff --git a/test/testlib/engines.py b/test/testlib/engines.py index df1d37d3cd..85e1efa3a4 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -69,11 +69,10 @@ def close_open_connections(fn): def all_dialects(): import sqlalchemy.databases as d for name in d.__all__: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) - yield mod.dialect() - import sqlalchemy.dialects as d - for name in d.__all__: - mod = getattr(__import__('sqlalchemy.dialects.%s.base' % name).dialects, name).base + # TEMPORARY + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() class ReconnectFixture(object):