]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Merged r2945, r2946, r2947 from trunk
authorJason Kirtland <jek@discorporate.us>
Wed, 18 Jul 2007 07:34:43 +0000 (07:34 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 18 Jul 2007 07:34:43 +0000 (07:34 +0000)
- Cache 'lower_case_table_names' test for the lifetime of a connection
- Clean up compat fetch stuff

lib/sqlalchemy/databases/mysql.py
test/engine/reflection.py

index 3e03109b78eeaf90ec9ee3dd90104b19885f41ce..3ff46e09425107eadfea17533a9ebd5fad583aa2 100644 (file)
@@ -4,14 +4,19 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, StringIO, string, types, re, datetime, inspect, warnings
+import re, datetime, inspect, warnings, weakref
 
-from sqlalchemy import sql,engine,schema,ansisql
+from sqlalchemy import sql, schema, ansisql
 from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 import sqlalchemy.util as util
-from array import array
+from array import array as _array
+
+try:
+    from threading import Lock
+except ImportError:
+    from dummy_threading import Lock
 
 RESERVED_WORDS = util.Set(
     ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
@@ -54,6 +59,7 @@ RESERVED_WORDS = util.Set(
      'accessible', 'linear', 'master_ssl_verify_server_cert', 'range',
      'read_only', 'read_write', # 5.1
      ])
+_per_connection_mutex = Lock()
 
 class _NumericType(object):
     "Base for MySQL numeric types."
@@ -954,6 +960,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
+        self.per_connection = weakref.WeakKeyDictionary()
 
     def dbapi(cls):
         import MySQLdb as mysql
@@ -1064,7 +1071,7 @@ class MySQLDialect(ansisql.ANSIDialect):
 
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
-            self._default_schema_name = text("select database()", self).scalar()
+            self._default_schema_name = sql.text("select database()", self).scalar()
         return self._default_schema_name
 
     def has_table(self, connection, table_name, schema=None):
@@ -1101,28 +1108,20 @@ class MySQLDialect(ansisql.ANSIDialect):
         """Load column definitions from the server."""
 
         decode_from = self._detect_charset(connection)
-
-        # reference:
-        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
-        row = _compat_fetch(connection.execute(
-            "SHOW VARIABLES LIKE 'lower_case_table_names'"),
-                            one=True, charset=decode_from)
-        if not row:
-            case_sensitive = True
-        else:
-            case_sensitive = row[1] in ('0', 'OFF' 'off')
+        case_sensitive = self._detect_case_sensitive(connection, decode_from)
 
         if not case_sensitive:
             table.name = table.name.lower()
             table.metadata.tables[table.name]= table
 
         try:
-            rp = connection.execute("describe " + self._escape_table_name(table),
-                                   {})
-        except:
-            raise exceptions.NoSuchTableError(table.fullname)
+            rp = connection.execute("DESCRIBE " + self._escape_table_name(table))
+        except exceptions.SQLError, e:
+            if e.orig.args[0] == 1146:
+                raise exceptions.NoSuchTableError(table.fullname)
+            raise
 
-        for row in _compat_fetch(rp, charset=decode_from):
+        for row in _compat_fetchall(rp, charset=decode_from):
             (name, type, nullable, primary_key, default) = \
                    (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4])
 
@@ -1173,7 +1172,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         """SHOW CREATE TABLE to get foreign key/table options."""
 
         rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
-        row = _compat_fetch(rp, one=True, charset=charset)
+        row = _compat_fetchone(rp, charset=charset)
         if not row:
             raise exceptions.NoSuchTableError(table.fullname)
         desc = row[1].strip()
@@ -1198,7 +1197,7 @@ class MySQLDialect(ansisql.ANSIDialect):
 
     def _escape_table_name(self, table):
         if table.schema is not None:
-            return '`%s`.`%s`' % (table.schema. table.name)
+            return '`%s`.`%s`' % (table.schema, table.name)
         else:
             return '`%s`' % table.name
 
@@ -1209,7 +1208,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         # on a connection via set_character_set()
         
         rs = connection.execute("show variables like 'character_set%%'")
-        opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)])
+        opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
 
         if 'character_set_results' in opts:
             return opts['character_set_results']
@@ -1223,13 +1222,41 @@ class MySQLDialect(ansisql.ANSIDialect):
                 warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python.  MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
                 return 'latin1'
 
-def _compat_fetch(rp, one=False, charset=None):
+    def _detect_case_sensitive(self, connection, charset=None):
+        """Sniff out identifier case sensitivity.
+
+        Cached per-connection. This value can not change without a server
+        restart.
+        """
+        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
+
+        _per_connection_mutex.acquire()
+        try:
+            raw_connection = connection.connection.connection
+            cache = self.per_connection.get(raw_connection, {})
+            if 'lower_case_table_names' not in cache:
+                row = _compat_fetchone(connection.execute(
+                        "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                        charset=charset)
+                if not row:
+                    cs = True
+                else:
+                    cs = row[1] in ('0', 'OFF' 'off')
+                cache['lower_case_table_names'] = cs
+                self.per_connection[raw_connection] = cache
+            return cache.get('lower_case_table_names')
+        finally:
+            _per_connection_mutex.release()
+
+def _compat_fetchall(rp, charset=None):
     """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
 
-    if one:
-        return _MySQLPythonRowProxy(rp.fetchone(), charset)
-    else:
-        return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+    return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+
+def _compat_fetchone(rp, charset=None):
+    """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+    return _MySQLPythonRowProxy(rp.fetchone(), charset)
         
 
 class _MySQLPythonRowProxy(object):
@@ -1240,7 +1267,7 @@ class _MySQLPythonRowProxy(object):
         self.charset = charset
     def __getitem__(self, index):
         item = self.rowproxy[index]
-        if isinstance(item, array):
+        if isinstance(item, _array):
             item = item.tostring()
         if self.charset and isinstance(item, unicode):
             return item.encode(self.charset)
@@ -1248,7 +1275,7 @@ class _MySQLPythonRowProxy(object):
             return item
     def __getattr__(self, attr):
         item = getattr(self.rowproxy, attr)
-        if isinstance(item, array):
+        if isinstance(item, _array):
             item = item.tostring()
         if self.charset and isinstance(item, unicode):
             return item.encode(self.charset)
index 85701599e8b4a152514ebb7e4f38e262bc0057f1..d2ee3106cb1038be965ee3c9cae17f9f1fdfb766 100644 (file)
@@ -428,10 +428,6 @@ class ReflectionTest(PersistTest):
         finally:
             meta.drop_all(testbase.db)
             
-    # mysql throws its own exception for no such table, resulting in 
-    # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError.
-    # this could probably be fixed at some point.
-    @testbase.unsupported('mysql')    
     def test_nonexistent(self):
         self.assertRaises(NoSuchTableError, Table,
                           'fake_table',
@@ -583,6 +579,23 @@ class SchemaTest(PersistTest):
         assert buf.index("CREATE TABLE someschema.table1") > -1
         assert buf.index("CREATE TABLE someschema.table2") > -1
     
+    @testbase.unsupported('sqlite')
+    def testcreate(self):
+        schema = testbase.db.url.database
+        metadata = MetaData(testbase.db)
+        table1 = Table('table1', metadata, 
+            Column('col1', Integer, primary_key=True),
+            schema=schema)
+        table2 = Table('table2', metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', Integer, ForeignKey('%s.table1.col1' % schema)),
+            schema=schema)
+        metadata.create_all()
+        metadata.create_all(checkfirst=True)
+        metadata.clear()
+        table1 = Table('table1', metadata, autoload=True, schema=schema)
+        table2 = Table('table2', metadata, autoload=True, schema=schema)
+        metadata.drop_all()
         
 if __name__ == "__main__":
     testbase.main()