]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more test fixup, type correction
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 16:58:20 +0000 (16:58 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2009 16:58:20 +0000 (16:58 +0000)
lib/sqlalchemy/databases/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/__init__.py
test/dialect/mysql.py
test/sql/select.py
test/sql/testtypes.py
test/testlib/engines.py

index a824cd87b206064e5008fbcaccf8167b14be354f..b45ea73d665465c75b8b5b767cf0dac9f6baac63 100644 (file)
@@ -4,6 +4,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+from sqlalchemy.dialects.sqlite import base as sqlite
+from sqlalchemy.dialects.postgres import base as postgres
+from sqlalchemy.dialects.mysql import base as mysql
+
 
 __all__ = (
     'access',
@@ -11,6 +15,9 @@ __all__ = (
     'informix',
     'maxdb',
     'mssql',
+    'mysql',
+    'postgres',
+    'sqlite',
     'oracle',
     'sybase',
     )
index 9c2cf0352a111934836bf2419b69e8aaeb7822b1..67c73efb7185a99cadb91539391946af5ca89fcf 100644 (file)
@@ -1127,7 +1127,7 @@ class MSSet(MSString):
           only the collation of character data.
 
         """
-        self.__ddl_values = values
+        self._ddl_values = values
 
         strip_values = []
         for a in values:
@@ -1545,7 +1545,6 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, 'REAL')
     
-    
     def visit_FLOAT(self, type_):
         if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None:
             return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale))
@@ -1647,6 +1646,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self.visit_BLOB(type_)
     
+    def visit_binary(self, type_):
+        return self.visit_BLOB(type_)
+        
     def visit_BINARY(self, type_):
         if type_.length:
             return "BINARY(%d)" % type_.length
@@ -1675,9 +1677,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         return self._extend_string(type_, "ENUM(%s)" % ",".join(quoted_enums))
         
     def visit_SET(self, type_):
-        return self._extend_string("SET(%s)" % ",".join(type_._ddl_values))
+        return self._extend_string(type_, "SET(%s)" % ",".join(type_._ddl_values))
 
-    def visit_BOOL(self, type):
+    def visit_BOOLEAN(self, type):
         return "BOOL"
         
 
@@ -1819,6 +1821,7 @@ class MySQLDialect(default.DefaultDialect):
             if rs:
                 rs.close()
 
+    @engine_base.connection_memoize(('mysql', 'server_version_info'))
     def server_version_info(self, connection):
         """A tuple of the database server version.
 
@@ -1831,9 +1834,9 @@ class MySQLDialect(default.DefaultDialect):
         cached per-Connection.
         """
 
+        # TODO: do we need to bypass ConnectionFairy here?  other calls
+        # to this seem to not do that.
         return self._server_version_info(connection.connection.connection)
-    server_version_info = engine_base.connection_memoize(
-        ('mysql', 'server_version_info'))(server_version_info)
 
     def reflecttable(self, connection, table, include_columns):
         """Load column definitions from the server."""
@@ -1850,7 +1853,7 @@ class MySQLDialect(default.DefaultDialect):
                 # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
                 preparer = MySQLIdentifierPreparer(self)
 
-            self.reflector = reflector = MySQLSchemaReflector(self)
+            self.reflector = reflector = MySQLSchemaReflector(self, preparer)
 
         sql = self._show_create_table(connection, table, charset)
         if sql.startswith('CREATE ALGORITHM'):
@@ -2047,7 +2050,7 @@ class MySQLDialect(default.DefaultDialect):
 class MySQLSchemaReflector(object):
     """Parses SHOW CREATE TABLE output."""
 
-    def __init__(self, dialect):
+    def __init__(self, dialect, preparer=None):
         """Construct a MySQLSchemaReflector.
 
         identifier_preparer
@@ -2056,7 +2059,7 @@ class MySQLSchemaReflector(object):
         """
 
         self.dialect = dialect
-        self.preparer = dialect.identifier_preparer
+        self.preparer = preparer or dialect.identifier_preparer
         self._prep_regexes()
 
     def reflect(self, connection, table, show_create, charset, only=None):
index e9672eddf713e85bb565004508f9be5ff5849fb4..15ed21c77d58f3e00c9c355779bd4d3e5311e9e7 100644 (file)
@@ -368,6 +368,7 @@ class PGDialect(default.DefaultDialect):
     supports_pk_autoincrement = False
     supports_default_values = True
     supports_empty_insert = False
+    default_paramstyle = 'pyformat'
 
     statement_compiler = PGCompiler
     ddl_compiler = PGDDLCompiler
index d716a58ec4c7cf6e14dfbca46d958cba499cb007..ba08ccbb90fb73729cb493f74133b183e43a3900 100644 (file)
@@ -219,6 +219,7 @@ class SQLiteDialect(default.DefaultDialect):
     supports_default_values = True
     supports_empty_insert = False
     supports_cast = True
+    default_paramstyle = 'qmark'
     statement_compiler = SQLiteCompiler
     ddl_compiler = SQLiteDDLCompiler
     type_compiler = SQLiteTypeCompiler
index 6def864e89d076e6706ea47d299a40887b0138c4..4dd5ea28618dbd3dc001428c3a413b4fdbc30615 100644 (file)
@@ -50,7 +50,9 @@ url.py
     within a URL.
 """
 
-import sqlalchemy.databases
+# not sure what this was used for
+#import sqlalchemy.databases  
+
 from sqlalchemy.engine.base import (
     BufferedColumnResultProxy,
     BufferedColumnRow,
index 0ca9240110914459b38c9686e557405ed727482e..ad16de89a33b6d00d6fe0a6f8419d1ead25d40a7 100644 (file)
@@ -178,7 +178,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         numeric_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table)
 
         for col in numeric_table.c:
             index = int(col.name[1:])
@@ -262,7 +262,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         charset_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table)
 
         for col in charset_table.c:
             index = int(col.name[1:])
@@ -741,7 +741,8 @@ class TypesTest(TestBase, AssertsExecutionResults):
 
                 for table in tables:
                     for i, reflected in enumerate(table.c):
-                        assert isinstance(reflected.type, type(expected[i]))
+                        assert isinstance(reflected.type, type(expected[i])), \
+                                "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i]))
             finally:
                 db.execute('DROP VIEW mysql_types_v')
         finally:
@@ -986,8 +987,7 @@ class SQLTest(TestBase, AssertsCompiledSQL):
 class RawReflectionTest(TestBase):
     def setUp(self):
         self.dialect = mysql.dialect()
-        self.reflector = mysql.MySQLSchemaReflector(
-            self.dialect.identifier_preparer)
+        self.reflector = mysql.MySQLSchemaReflector(self.dialect)
 
     def test_key_reflection(self):
         regex = self.reflector._re_key
@@ -1147,8 +1147,7 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
 
 def colspec(c):
-    return testing.db.dialect.schemagenerator(testing.db.dialect,
-        testing.db, None, None).get_column_specification(c)
+    return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c)
 
 if __name__ == "__main__":
     testenv.main()
index 671ccab1a03d79d3073dc3c1e1e1a62b0cbe029e..77112e649a08a1cdc83f5545a8a5159b20cc383b 100644 (file)
@@ -5,9 +5,7 @@ from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, label, compiler
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.engine import default
-from sqlalchemy.databases import mysql, oracle, firebird, mssql
-from sqlalchemy.dialects.sqlite import pysqlite as sqlite
-from sqlalchemy.dialects.postgres import psycopg2 as postgres
+from sqlalchemy.databases import *
 from testlib import *
 
 table1 = table('mytable',
@@ -1311,7 +1309,6 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
         s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')])
         assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg']
 
-        from sqlalchemy.databases.sqlite import SLNumeric
         meta = MetaData()
         t1 = Table('mytable', meta, Column('col1', Integer))
         
@@ -1319,7 +1316,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
             (table1.c.name, 'name', 'mytable.name', None),
             (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'),
             (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'),
-            (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'),
+            (cast(table1.c.name, sqlite.SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'),
             (t1.c.col1, 'col1', 'mytable.col1', None),
             (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '')
         ):
index da649d09703c0e64a0ff1f1256b44c3ab1a80a57..9ce7b7662ab1715aedfa5a0978e89721268e83d0 100644 (file)
@@ -6,9 +6,7 @@ from sqlalchemy import exc, types, util, schema
 from sqlalchemy.sql import operators
 from testlib.testing import eq_
 import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, firebird
-from sqlalchemy.dialects.sqlite import pysqlite as sqlite
-from sqlalchemy.dialects.postgres import psycopg2 as postgres
+from sqlalchemy.databases import *
 
 from testlib import *
 
index df1d37d3cd4fec647ac457965515add9647a47dc..85e1efa3a4e8420a9c6782522c483a53c7dcd552 100644 (file)
@@ -69,11 +69,10 @@ def close_open_connections(fn):
 def all_dialects():
     import sqlalchemy.databases as d
     for name in d.__all__:
-        mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
-        yield mod.dialect()
-    import sqlalchemy.dialects as d
-    for name in d.__all__:
-        mod = getattr(__import__('sqlalchemy.dialects.%s.base' % name).dialects, name).base
+        # TEMPORARY
+        mod = getattr(d, name, None)
+        if not mod:
+            mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
         yield mod.dialect()
         
 class ReconnectFixture(object):