From: Jason Kirtland Date: Wed, 18 Jul 2007 07:34:43 +0000 (+0000) Subject: - Merged r2945, r2946, r2947 from trunk X-Git-Tag: rel_0_4_6~85 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d0910b22755a699f71eb549dc09a9f9423580d5c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Merged r2945, r2946, r2947 from trunk - Cache 'lower_case_table_names' test for the lifetime of a connection - Clean up compat fetch stuff --- diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 3e03109b78..3ff46e0942 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,14 +4,19 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types, re, datetime, inspect, warnings +import re, datetime, inspect, warnings, weakref -from sqlalchemy import sql,engine,schema,ansisql +from sqlalchemy import sql, schema, ansisql from sqlalchemy.engine import default import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util -from array import array +from array import array as _array + +try: + from threading import Lock +except ImportError: + from dummy_threading import Lock RESERVED_WORDS = util.Set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', @@ -54,6 +59,7 @@ RESERVED_WORDS = util.Set( 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', 'read_only', 'read_write', # 5.1 ]) +_per_connection_mutex = Lock() class _NumericType(object): "Base for MySQL numeric types." @@ -954,6 +960,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) + self.per_connection = weakref.WeakKeyDictionary() def dbapi(cls): import MySQLdb as mysql @@ -1064,7 +1071,7 @@ class MySQLDialect(ansisql.ANSIDialect): def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): - self._default_schema_name = text("select database()", self).scalar() + self._default_schema_name = sql.text("select database()", self).scalar() return self._default_schema_name def has_table(self, connection, table_name, schema=None): @@ -1101,28 +1108,20 @@ class MySQLDialect(ansisql.ANSIDialect): """Load column definitions from the server.""" decode_from = self._detect_charset(connection) - - # reference: - # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - row = _compat_fetch(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - one=True, charset=decode_from) - if not row: - case_sensitive = True - else: - case_sensitive = row[1] in ('0', 'OFF' 'off') + case_sensitive = self._detect_case_sensitive(connection, decode_from) if not case_sensitive: table.name = table.name.lower() table.metadata.tables[table.name]= table try: - rp = connection.execute("describe " + self._escape_table_name(table), - {}) - except: - raise exceptions.NoSuchTableError(table.fullname) + rp = connection.execute("DESCRIBE " + self._escape_table_name(table)) + except exceptions.SQLError, e: + if e.orig.args[0] == 1146: + raise exceptions.NoSuchTableError(table.fullname) + raise - for row in _compat_fetch(rp, charset=decode_from): + for row in _compat_fetchall(rp, charset=decode_from): (name, type, nullable, primary_key, default) = \ (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4]) @@ -1173,7 +1172,7 @@ class MySQLDialect(ansisql.ANSIDialect): """SHOW CREATE TABLE to get foreign key/table options.""" rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {}) - row = _compat_fetch(rp, one=True, charset=charset) + row = _compat_fetchone(rp, charset=charset) if not row: raise exceptions.NoSuchTableError(table.fullname) desc = row[1].strip() @@ -1198,7 +1197,7 @@ class MySQLDialect(ansisql.ANSIDialect): def _escape_table_name(self, table): if table.schema is not None: - return '`%s`.`%s`' % (table.schema. table.name) + return '`%s`.`%s`' % (table.schema, table.name) else: return '`%s`' % table.name @@ -1209,7 +1208,7 @@ class MySQLDialect(ansisql.ANSIDialect): # on a connection via set_character_set() rs = connection.execute("show variables like 'character_set%%'") - opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)]) + opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) if 'character_set_results' in opts: return opts['character_set_results'] @@ -1223,13 +1222,41 @@ class MySQLDialect(ansisql.ANSIDialect): warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python. MySQL-python >= 1.2.2 is recommended. Assuming latin1.")) return 'latin1' -def _compat_fetch(rp, one=False, charset=None): + def _detect_case_sensitive(self, connection, charset=None): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + """ + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html + + _per_connection_mutex.acquire() + try: + raw_connection = connection.connection.connection + cache = self.per_connection.get(raw_connection, {}) + if 'lower_case_table_names' not in cache: + row = _compat_fetchone(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset) + if not row: + cs = True + else: + cs = row[1] in ('0', 'OFF' 'off') + cache['lower_case_table_names'] = cs + self.per_connection[raw_connection] = cache + return cache.get('lower_case_table_names') + finally: + _per_connection_mutex.release() + +def _compat_fetchall(rp, charset=None): """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" - if one: - return _MySQLPythonRowProxy(rp.fetchone(), charset) - else: - return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + +def _compat_fetchone(rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _MySQLPythonRowProxy(rp.fetchone(), charset) class _MySQLPythonRowProxy(object): @@ -1240,7 +1267,7 @@ class _MySQLPythonRowProxy(object): self.charset = charset def __getitem__(self, index): item = self.rowproxy[index] - if isinstance(item, array): + if isinstance(item, _array): item = item.tostring() if self.charset and isinstance(item, unicode): return item.encode(self.charset) @@ -1248,7 +1275,7 @@ class _MySQLPythonRowProxy(object): return item def __getattr__(self, attr): item = getattr(self.rowproxy, attr) - if isinstance(item, array): + if isinstance(item, _array): item = item.tostring() if self.charset and isinstance(item, unicode): return item.encode(self.charset) diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 85701599e8..d2ee3106cb 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -428,10 +428,6 @@ class ReflectionTest(PersistTest): finally: meta.drop_all(testbase.db) - # mysql throws its own exception for no such table, resulting in - # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError. - # this could probably be fixed at some point. - @testbase.unsupported('mysql') def test_nonexistent(self): self.assertRaises(NoSuchTableError, Table, 'fake_table', @@ -583,6 +579,23 @@ class SchemaTest(PersistTest): assert buf.index("CREATE TABLE someschema.table1") > -1 assert buf.index("CREATE TABLE someschema.table2") > -1 + @testbase.unsupported('sqlite') + def testcreate(self): + schema = testbase.db.url.database + metadata = MetaData(testbase.db) + table1 = Table('table1', metadata, + Column('col1', Integer, primary_key=True), + schema=schema) + table2 = Table('table2', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', Integer, ForeignKey('%s.table1.col1' % schema)), + schema=schema) + metadata.create_all() + metadata.create_all(checkfirst=True) + metadata.clear() + table1 = Table('table1', metadata, autoload=True, schema=schema) + table2 = Table('table2', metadata, autoload=True, schema=schema) + metadata.drop_all() if __name__ == "__main__": testbase.main()