]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merged lower case caching, fetching from r2955
authorJason Kirtland <jek@discorporate.us>
Fri, 20 Jul 2007 19:43:46 +0000 (19:43 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 20 Jul 2007 19:43:46 +0000 (19:43 +0000)
Be sure to close rows fetched in reflection (if not autoclosed)
Fixed bind test, needed transactional storage engine for mysql

lib/sqlalchemy/databases/mysql.py
test/engine/bind.py

index f10eae7aeb718be6431fe7681a26c41de9b7a83a..bac0e5e12abb6cfdcf52373caa103060f2debb67 100644 (file)
@@ -4,15 +4,21 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import re, datetime, inspect, warnings
+import re, datetime, inspect, warnings, weakref
 
-from sqlalchemy import sql,schema,ansisql
+from sqlalchemy import sql, schema, ansisql
 from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 import sqlalchemy.util as util
 from array import array as _array
 
+try:
+    from threading import Lock
+except ImportError:
+    from dummy_threading import Lock
+
+
 RESERVED_WORDS = util.Set(
     ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
      'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both',
@@ -55,6 +61,8 @@ RESERVED_WORDS = util.Set(
      'read_only', 'read_write', # 5.1
      ])
 
+_per_connection_mutex = Lock()
+
 class _NumericType(object):
     "Base for MySQL numeric types."
 
@@ -951,6 +959,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
+        self.per_connection = weakref.WeakKeyDictionary()
 
     def dbapi(cls):
         import MySQLdb as mysql
@@ -1083,16 +1092,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         """Load column definitions from the server."""
 
         decode_from = self._detect_charset(connection)
-
-        # reference:
-        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
-        row = _compat_fetch(connection.execute(
-            "SHOW VARIABLES LIKE 'lower_case_table_names'"),
-                            one=True, charset=decode_from)
-        if not row:
-            case_sensitive = True
-        else:
-            case_sensitive = row[1] in ('0', 'OFF' 'off')
+        case_sensitive = self._detect_case_sensitive(connection, decode_from)
 
         if not case_sensitive:
             table.name = table.name.lower()
@@ -1105,7 +1105,7 @@ class MySQLDialect(ansisql.ANSIDialect):
                 raise exceptions.NoSuchTableError(table.fullname)
             raise
 
-        for row in _compat_fetch(rp, charset=decode_from):
+        for row in _compat_fetchall(rp, charset=decode_from):
             (name, type, nullable, primary_key, default) = \
                    (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4])
 
@@ -1152,10 +1152,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         """SHOW CREATE TABLE to get foreign key/table options."""
 
         rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
-        row = _compat_fetch(rp, one=True, charset=charset)
+        row = _compat_fetchone(rp, charset=charset)
         if not row:
             raise exceptions.NoSuchTableError(table.fullname)
         desc = row[1].strip()
+        row.close()
 
         tabletype = ''
         lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1188,7 +1189,7 @@ class MySQLDialect(ansisql.ANSIDialect):
         # on a connection via set_character_set()
         
         rs = connection.execute("show variables like 'character_set%%'")
-        opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)])
+        opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
 
         if 'character_set_results' in opts:
             return opts['character_set_results']
@@ -1202,13 +1203,42 @@ class MySQLDialect(ansisql.ANSIDialect):
                 warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python.  MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
                 return 'latin1'
 
-def _compat_fetch(rp, one=False, charset=None):
+    def _detect_case_sensitive(self, connection, charset=None):
+        """Sniff out identifier case sensitivity.
+
+        Cached per-connection. This value can not change without a server
+        restart.
+        """
+        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
+
+        _per_connection_mutex.acquire()
+        try:
+            raw_connection = connection.connection.connection
+            cache = self.per_connection.get(raw_connection, {})
+            if 'lower_case_table_names' not in cache:
+                row = _compat_fetchone(connection.execute(
+                        "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                        charset=charset)
+                if not row:
+                    cs = True
+                else:
+                    cs = row[1] in ('0', 'OFF' 'off')
+                    row.close()
+                cache['lower_case_table_names'] = cs
+                self.per_connection[raw_connection] = cache
+            return cache.get('lower_case_table_names')
+        finally:
+            _per_connection_mutex.release()
+
+def _compat_fetchall(rp, charset=None):
     """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
 
-    if one:
-        return _MySQLPythonRowProxy(rp.fetchone(), charset)
-    else:
-        return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+    return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+
+def _compat_fetchone(rp, charset=None):
+    """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+    return _MySQLPythonRowProxy(rp.fetchone(), charset)
         
 
 class _MySQLPythonRowProxy(object):
index 33b3b0049ae28b51d43cf3cb7fe529f5f0cfe632..79457777a0ee827808bfa1ca72bd392a6ad1ad7b 100644 (file)
@@ -91,7 +91,7 @@ class BindTest(testbase.PersistTest):
             ):
                 metadata = MetaData(*args[0], **args[1])
                 table = Table('test_table', metadata,   
-                    Column('foo', Integer))
+                              Column('foo', Integer))
 
                 assert metadata.bind is metadata.engine is table.bind is table.engine is bind
                 metadata.create_all()
@@ -104,7 +104,8 @@ class BindTest(testbase.PersistTest):
     def test_implicit_execution(self):
         metadata = MetaData()
         table = Table('test_table', metadata,   
-            Column('foo', Integer))
+            Column('foo', Integer),
+            mysql_engine='InnoDB')
         conn = testbase.db.connect()
         metadata.create_all(bind=conn)
         try:
@@ -189,4 +190,4 @@ class BindTest(testbase.PersistTest):
         
                
 if __name__ == '__main__':
-    testbase.main()
\ No newline at end of file
+    testbase.main()