]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- connection initialize moves to a connection pool event [ticket:1340]
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 21 Mar 2009 19:34:45 +0000 (19:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 21 Mar 2009 19:34:45 +0000 (19:34 +0000)
- sqlite doesn't support schemas.  not sure if some versions do, but marking those as unsupported for now.
- added a testing.requires callable for schema support.
- standardized the "extra schema" name for unit tests as "test_schema" and "test_schema_2".
- sqlite needs description_encoding (was some other version of pysqlite tested here ?)
- other test fixes.

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py
lib/sqlalchemy/engine/strategies.py
test/dialect/postgres.py
test/engine/parseconnect.py
test/engine/reconnect.py
test/engine/reflection.py
test/testlib/requires.py

index df6d09d2abb28d5a513c3d78fefa6bfe1d00194c..cbb79662da9ea52222c058b745be86e081a4eb34 100644 (file)
@@ -1080,9 +1080,6 @@ class MSDialect(default.DefaultDialect):
                 self.max_identifier_length
         super(MSDialect, self).__init__(**opts)
     
-    def initialize(self, connection):
-        self.server_version_info = self._get_server_version_info(connection)
-    
     def do_begin(self, connection):
         cursor = connection.cursor()
         cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
index 627ed84a1d23eba5318f83bf7efb529cdd3f990e..9dd2bfe7151978d4421c1956a5f04c2c0e4450ae 100644 (file)
@@ -1778,11 +1778,10 @@ class MySQLDialect(default.DefaultDialect):
 
     def _extract_error_code(self, exception):
         raise NotImplementedError()
-        
+    
+    @engine_base.connection_memoize(('dialect', 'default_schema_name'))
     def get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
-    get_default_schema_name = engine_base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
 
     def table_names(self, connection, schema):
         """Return a Unicode SHOW TABLES from a given schema."""
index a67c42ff843d86a43bcdf66f02c8c664d0124102..5f7636bba94ff39cfef0647d604f9a45b7a92a5f 100644 (file)
@@ -115,10 +115,6 @@ class MySQL_mysqldb(MySQLDialect):
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
 
-        # Allow user override, won't sniff if force_charset is set.
-        if ('mysql', 'force_charset') in connection.info:
-            return connection.info[('mysql', 'force_charset')]
-
         # Note: MySQL-python 1.2.1c7 seems to ignore changes made
         # on a connection via set_character_set()
         if self.server_version_info < (4, 1, 0):
index ddd836795c8df774214427794948dd2f0edc64fe..13e5c16f5c4c5baa6eb63927dcc36055bdb5d58f 100644 (file)
@@ -121,7 +121,6 @@ class SQLite_pysqlite(SQLiteDialect):
     default_paramstyle = 'qmark'
     poolclass = pool.SingletonThreadPool
     execution_ctx_cls = SQLite_pysqliteExecutionContext
-    description_encoding = None
     driver = 'pysqlite'
     
     def __init__(self, **kwargs):
index 75409c10f7e167001c65241b5b3ee7a15e899976..5187ab1927ad0eb22f6256490dd93be6f3b23be1 100644 (file)
@@ -13,7 +13,7 @@ from operator import attrgetter
 from sqlalchemy.engine import base, threadlocal, url
 from sqlalchemy import util, exc
 from sqlalchemy import pool as poollib
-
+from sqlalchemy import interfaces
 
 strategies = {}
 
@@ -123,13 +123,15 @@ class DefaultEngineStrategy(EngineStrategy):
                                     engineclass.__name__))
                                     
         engine = engineclass(pool, dialect, u, **engine_args)
-        
+
         if _initialize:
-            conn = engine.connect()
-            try:
-                dialect.initialize(conn)
-            finally:
-                conn.close()
+            class OnInit(interfaces.PoolListener):
+                def connect(self, conn, rec):
+                    c = base.Connection(engine, connection=conn)
+                    dialect.initialize(c)
+                    pool._on_connect.remove(self)
+            pool._on_connect.insert(0, OnInit())
+        
         return engine
 
     def pool_threadlocal(self):
index f4807e18e6c7dd38a293984a9fc976c876c26071..6e6500af152173291bd4150fb5537461381b05fe 100644 (file)
@@ -396,23 +396,23 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
     def setUpAll(self):
         con = testing.db.connect()
         for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
-                    'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'):
+                    'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0'):
             try:
                 con.execute(ddl)
             except exc.SQLError, e:
                 if not "already exists" in str(e):
                     raise e
         con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
-        con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
-        con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
+        con.execute('CREATE TABLE test_schema.testtable(question integer, answer test_schema.testdomain, anything integer)')
+        con.execute('CREATE TABLE crosschema (question integer, answer test_schema.testdomain)')
 
     def tearDownAll(self):
         con = testing.db.connect()
         con.execute('DROP TABLE testtable')
-        con.execute('DROP TABLE alt_schema.testtable')
+        con.execute('DROP TABLE test_schema.testtable')
         con.execute('DROP TABLE crosschema')
         con.execute('DROP DOMAIN testdomain')
-        con.execute('DROP DOMAIN alt_schema.testdomain')
+        con.execute('DROP DOMAIN test_schema.testdomain')
 
     def test_table_is_reflected(self):
         metadata = MetaData(testing.db)
@@ -426,15 +426,15 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         self.assertEquals(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
         self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
 
-    def test_table_is_reflected_alt_schema(self):
+    def test_table_is_reflected_test_schema(self):
         metadata = MetaData(testing.db)
-        table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+        table = Table('testtable', metadata, autoload=True, schema='test_schema')
         self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
         assert isinstance(table.c.anything.type, Integer)
 
     def test_schema_domain_is_reflected(self):
         metadata = MetaData(testing.db)
-        table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+        table = Table('testtable', metadata, autoload=True, schema='test_schema')
         self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
         self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
 
@@ -529,26 +529,26 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             'FROM mytable')
 
     def test_schema_reflection(self):
-        """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
+        """note: this test requires that the 'test_schema' schema be separate and accessible by the test user"""
 
         meta1 = MetaData(testing.db)
         users = Table('users', meta1,
             Column('user_id', Integer, primary_key = True),
             Column('user_name', String(30), nullable = False),
-            schema="alt_schema"
+            schema="test_schema"
             )
 
         addresses = Table('email_addresses', meta1,
             Column('address_id', Integer, primary_key = True),
             Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
-            schema="alt_schema"
+            schema="test_schema"
         )
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
-            addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema")
-            users = Table('users', meta2, mustexist=True, schema="alt_schema")
+            addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema")
+            users = Table('users', meta2, mustexist=True, schema="test_schema")
 
             print users
             print addresses
@@ -567,12 +567,12 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         referer = Table("referer", meta1,
                         Column("id", Integer, primary_key=True),
                         Column("ref", Integer, ForeignKey('subject.id')),
-                        schema="alt_schema")
+                        schema="test_schema")
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
             subject = Table("subject", meta2, autoload=True)
-            referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+            referer = Table("referer", meta2, schema="test_schema", autoload=True)
             print str(subject.join(referer).onclause)
             self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
         finally:
@@ -582,19 +582,19 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         meta1 = MetaData(testing.db)
         subject = Table("subject", meta1,
                         Column("id", Integer, primary_key=True),
-                        schema='alt_schema_2'
+                        schema='test_schema_2'
                         )
 
         referer = Table("referer", meta1,
                         Column("id", Integer, primary_key=True),
-                        Column("ref", Integer, ForeignKey('alt_schema_2.subject.id')),
-                        schema="alt_schema")
+                        Column("ref", Integer, ForeignKey('test_schema_2.subject.id')),
+                        schema="test_schema")
 
         meta1.create_all()
         try:
             meta2 = MetaData(testing.db)
-            subject = Table("subject", meta2, autoload=True, schema="alt_schema_2")
-            referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+            subject = Table("subject", meta2, autoload=True, schema="test_schema_2")
+            referer = Table("referer", meta2, schema="test_schema", autoload=True)
             print str(subject.join(referer).onclause)
             self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
         finally:
@@ -604,7 +604,7 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         meta = MetaData(testing.db)
         users = Table('users', meta,
             Column('id', Integer, primary_key=True),
-            Column('name', String(50)), schema='alt_schema')
+            Column('name', String(50)), schema='test_schema')
         users.create()
         try:
             users.insert().execute(id=1, name='name1')
index 4a6ca90d1444e249e7aa29dc0dc3869693ffa19b..327a470a591f1e0ce9c2363f31db8b4673b7215d 100644 (file)
@@ -42,15 +42,22 @@ class CreateEngineTest(TestBase):
     def test_connect_query(self):
         dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue')
 
-        # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi)
+        e = create_engine(
+                'postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', 
+                module=dbapi,
+                _initialize=False
+                )
         c = e.connect()
 
     def test_kwargs(self):
         dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
 
-        # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi)
+        e = create_engine(
+                'postgres://scott:tiger@somehost/test?fooz=somevalue', 
+                connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, 
+                module=dbapi,
+                _initialize=False
+                )
         c = e.connect()
 
     def test_coerce_config(self):
@@ -118,7 +125,7 @@ pool_timeout=10
             return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'})
 
         # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
-        e = create_engine('postgres://', creator=connect, module=dbapi)
+        e = create_engine('postgres://', creator=connect, module=dbapi, _initialize=False)
         c = e.connect()
 
     def test_recycle(self):
@@ -145,7 +152,8 @@ pool_timeout=10
         self.assertRaises(TypeError, create_engine, 'sqlite://', max_overflow=5)
 
         # raises DBAPIerror due to use_unicode not a sqlite arg
-        self.assertRaises(tsa.exc.DBAPIError, create_engine, 'sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+        e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+        self.assertRaises(tsa.exc.DBAPIError, e.connect)
 
     def test_urlattr(self):
         """test the url attribute on ``Engine``."""
index 10c80e13526083eda8ffbb4a685bb4dd0572123b..a1b08c6a390881c7c37e6f71a520293c030ea900 100644 (file)
@@ -52,7 +52,7 @@ class MockReconnectTest(TestBase):
         dbapi = MockDBAPI()
 
         # create engine using our current dburi
-        db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+        db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi, _initialize=False)
 
         # monkeypatch disconnect checker
         db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
index 8e80a2898f487ff795a12b0bffc9d38658bfad95..f133638fd5340b7ab6856deaef3de2aa30250bfa 100644 (file)
@@ -8,11 +8,6 @@ from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
 
 create_inspector = Inspector.from_engine
 
-# Py2K
-if 'set' not in dir(__builtins__):
-    from sets import Set as set
-
-# end Py2K
 metadata, users = None, None
 
 class ReflectionTest(TestBase, ComparesTables):
@@ -805,12 +800,9 @@ class HasSequenceTest(TestBase):
 
 # Tests related to engine.reflection
 
-def getSchema():
-    if testing.against('sqlite'):
-        return None
-
-    if testing.against('sqlite'):
-        return 'main'
+def get_schema():
+#    if testing.against('sqlite'):
+#        return None
     if testing.against('oracle'):
         return 'test'
     else:
@@ -884,10 +876,12 @@ def dropViews(con, schema=None):
 
 class ReflectionTest(TestBase):
 
+    @testing.fails_on('sqlite', 'no schemas')
     def test_get_schema_names(self):
         meta = MetaData(testing.db)
         insp = Inspector(meta.bind)
-        self.assert_(getSchema() in insp.get_schema_names())
+        
+        self.assert_(get_schema() in insp.get_schema_names())
 
     def _test_get_table_names(self, schema=None, table_type='table',
                               order_by=None):
@@ -918,14 +912,16 @@ class ReflectionTest(TestBase):
     def test_get_table_names(self):
         self._test_get_table_names()
 
+    @testing.requires.schemas
     def test_get_table_names_with_schema(self):
-        self._test_get_table_names(getSchema())
+        self._test_get_table_names(get_schema())
 
     def test_get_view_names(self):
         self._test_get_table_names(table_type='view')
 
+    @testing.requires.schemas
     def test_get_view_names_with_schema(self):
-        self._test_get_table_names(getSchema(), table_type='view')
+        self._test_get_table_names(get_schema(), table_type='view')
 
     def _test_get_columns(self, schema=None, table_type='table'):
         meta = MetaData(testing.db)
@@ -974,14 +970,16 @@ class ReflectionTest(TestBase):
     def test_get_columns(self):
         self._test_get_columns()
 
+    @testing.requires.schemas
     def test_get_columns_with_schema(self):
-        self._test_get_columns(schema=getSchema())
+        self._test_get_columns(schema=get_schema())
 
     def test_get_view_columns(self):
         self._test_get_columns(table_type='view')
 
+    @testing.requires.schemas
     def test_get_view_columns_with_schema(self):
-        self._test_get_columns(schema=getSchema(), table_type='view')
+        self._test_get_columns(schema=get_schema(), table_type='view')
 
     def _test_get_primary_keys(self, schema=None):
         meta = MetaData(testing.db)
@@ -1003,8 +1001,9 @@ class ReflectionTest(TestBase):
     def test_get_primary_keys(self):
         self._test_get_primary_keys()
 
+    @testing.fails_on('sqlite', 'no schemas')
     def test_get_primary_keys_with_schema(self):
-        self._test_get_primary_keys(schema=getSchema())
+        self._test_get_primary_keys(schema=get_schema())
 
     def _test_get_foreign_keys(self, schema=None):
         meta = MetaData(testing.db)
@@ -1044,8 +1043,9 @@ class ReflectionTest(TestBase):
     def test_get_foreign_keys(self):
         self._test_get_foreign_keys()
 
+    @testing.requires.schemas
     def test_get_foreign_keys_with_schema(self):
-        self._test_get_foreign_keys(schema=getSchema())
+        self._test_get_foreign_keys(schema=get_schema())
 
     def _test_get_indexes(self, schema=None):
         meta = MetaData(testing.db)
@@ -1082,8 +1082,9 @@ class ReflectionTest(TestBase):
     def test_get_indexes(self):
         self._test_get_indexes()
 
+    @testing.requires.schemas
     def test_get_indexes_with_schema(self):
-        self._test_get_indexes(schema=getSchema())
+        self._test_get_indexes(schema=get_schema())
 
     def _test_get_view_definition(self, schema=None):
         meta = MetaData(testing.db)
@@ -1106,8 +1107,9 @@ class ReflectionTest(TestBase):
     def test_get_view_definition(self):
         self._test_get_view_definition()
 
+    @testing.requires.schemas
     def test_get_view_definition_with_schema(self):
-        self._test_get_view_definition(schema=getSchema())
+        self._test_get_view_definition(schema=get_schema())
 
     def _test_get_table_oid(self, table_name, schema=None):
         if testing.against('postgres'):
@@ -1125,8 +1127,9 @@ class ReflectionTest(TestBase):
     def test_get_table_oid(self):
         self._test_get_table_oid('users')
 
+    @testing.requires.schemas
     def test_get_table_oid_with_schema(self):
-        self._test_get_table_oid('users', schema=getSchema())
+        self._test_get_table_oid('users', schema=get_schema())
 
 
 if __name__ == "__main__":
index 4ccce962050af2cff9cc2948755c644ccb454414..1c66e3e51b81976a950f0dd3a9d7f19f1d6ed766 100644 (file)
@@ -74,6 +74,14 @@ def savepoints(fn):
         exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
         )
 
+def schemas(fn):
+    """Target database must support external schemas, and have one named 'test_schema'."""
+    
+    return _chain_decorators_on(
+        fn,
+        no_support('sqlite', 'no schema support')
+    )
+    
 def sequences(fn):
     """Target database must support SEQUENCEs."""
     return _chain_decorators_on(