]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Revamped the Connection memoize decorator a bit, moved to engine
authorJason Kirtland <jek@discorporate.us>
Wed, 2 Apr 2008 11:39:26 +0000 (11:39 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 2 Apr 2008 11:39:26 +0000 (11:39 +0000)
- MySQL character set caching is more aggressive but will invalidate the cache if a SET is issued.
- MySQL connection memos are namespaced: info[('mysql', 'server_variable')]

CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/util.py
test/dialect/mysql.py

diff --git a/CHANGES b/CHANGES
index 4797c3fc534d4828435148968abcabfd27426c79..7207f1322a7af946dc4147d846d7cc90e239ec38 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -253,6 +253,10 @@ CHANGES
      - Improvements to pyodbc + Unix. If you couldn't get that
        combination to work before, please try again.
 
+- mysql
+     - The connection.info keys the dialect uses to cache server
+       settings have changed and are now namespaced.
+
 0.4.4
 ------
 - sql
index 18b236d1c8e1a4665fc8048a50ddef7150d656da..c4b8af68489afa5ee35a48c925759e30aab6e6a3 100644 (file)
@@ -157,7 +157,6 @@ import datetime, inspect, re, sys
 from array import array as _array
 
 from sqlalchemy import exceptions, logging, schema, sql, util
-from sqlalchemy.pool import connection_cache_decorator
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
@@ -224,6 +223,10 @@ AUTOCOMMIT_RE = re.compile(
 SELECT_RE = re.compile(
     r'\s*(?:SELECT|SHOW|DESCRIBE|XA RECOVER)',
     re.I | re.UNICODE)
+SET_RE = re.compile(
+    r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w',
+    re.I | re.UNICODE)
+
 
 class _NumericType(object):
     """Base for MySQL numeric types."""
@@ -1396,6 +1399,11 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                 self._last_inserted_ids[0] is None):
                 self._last_inserted_ids = ([self.cursor.lastrowid] +
                                            self._last_inserted_ids[1:])
+        elif (not self.isupdate and not self.should_autocommit and
+              self.statement and SET_RE.match(self.statement)):
+            # This misses if a user forces autocommit on text('SET NAMES'),
+            # which is probably a programming error anyhow.
+            self.connection.info.pop(('mysql', 'charset'), None)
 
     def returns_rows_text(self, statement):
         return SELECT_RE.match(statement)
@@ -1544,8 +1552,9 @@ class MySQLDialect(default.DefaultDialect):
 
     def get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
-    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
-    
+    get_default_schema_name = engine_base.connection_memoize(
+        ('dialect', 'default_schema_name'))(get_default_schema_name)
+
     def table_names(self, connection, schema):
         """Return a Unicode SHOW TABLES from a given schema."""
 
@@ -1599,12 +1608,9 @@ class MySQLDialect(default.DefaultDialect):
         cached per-Connection.
         """
 
-        try:
-            return connection.info['_mysql_server_version_info']
-        except KeyError:
-            version = connection.info['_mysql_server_version_info'] = \
-              self._server_version_info(connection.connection.connection)
-            return version
+        return self._server_version_info(connection.connection.connection)
+    server_version_info = engine_base.connection_memoize(
+        ('mysql', 'server_version_info'))(server_version_info)
 
     def _server_version_info(self, dbapi_con):
         """Convert a MySQL-python server_info string into a tuple."""
@@ -1654,7 +1660,7 @@ class MySQLDialect(default.DefaultDialect):
             columns = self._describe_table(connection, table, charset)
             sql = reflector._describe_to_create(table, columns)
 
-        self._adjust_casing(connection, table, charset)
+        self._adjust_casing(connection, table)
 
         return reflector.reflect(connection, table, sql, charset,
                                  only=include_columns)
@@ -1662,10 +1668,7 @@ class MySQLDialect(default.DefaultDialect):
     def _adjust_casing(self, connection, table, charset=None):
         """Adjust Table name to the server case sensitivity, if needed."""
 
-        if charset is None:
-            charset = self._detect_charset(connection)
-
-        casing = self._detect_casing(connection, charset)
+        casing = self._detect_casing(connection)
 
         # For winxx database hosts.  TODO: is this really needed?
         if casing == 1 and table.name != table.name.lower():
@@ -1678,8 +1681,8 @@ class MySQLDialect(default.DefaultDialect):
         """Sniff out the character set in use for connection results."""
 
         # Allow user override, won't sniff if force_charset is set.
-        if 'force_charset' in connection.info:
-            return connection.info['force_charset']
+        if ('mysql', 'force_charset') in connection.info:
+            return connection.info[('mysql', 'force_charset')]
 
         # Note: MySQL-python 1.2.1c7 seems to ignore changes made
         # on a connection via set_character_set()
@@ -1714,55 +1717,56 @@ class MySQLDialect(default.DefaultDialect):
                     "combination of MySQL server and MySQL-python. "
                     "MySQL-python >= 1.2.2 is recommended.  Assuming latin1.")
                 return 'latin1'
+    _detect_charset = engine_base.connection_memoize(
+        ('mysql', 'charset'))(_detect_charset)
+
 
-    def _detect_casing(self, connection, charset=None):
+    def _detect_casing(self, connection):
         """Sniff out identifier case sensitivity.
 
         Cached per-connection. This value can not change without a server
         restart.
-        """
 
+        """
         # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
 
-        try:
-            return connection.info['lower_case_table_names']
-        except KeyError:
-            row = _compat_fetchone(connection.execute(
-                    "SHOW VARIABLES LIKE 'lower_case_table_names'"),
-                                   charset=charset)
-            if not row:
+        charset = self._detect_charset(connection)
+        row = _compat_fetchone(connection.execute(
+            "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                               charset=charset)
+        if not row:
+            cs = 0
+        else:
+            # 4.0.15 returns OFF or ON according to [ticket:489]
+            # 3.23 doesn't, 4.0.27 doesn't..
+            if row[1] == 'OFF':
                 cs = 0
+            elif row[1] == 'ON':
+                cs = 1
             else:
-                # 4.0.15 returns OFF or ON according to [ticket:489]
-                # 3.23 doesn't, 4.0.27 doesn't..
-                if row[1] == 'OFF':
-                    cs = 0
-                elif row[1] == 'ON':
-                    cs = 1
-                else:
-                    cs = int(row[1])
-                row.close()
-            connection.info['lower_case_table_names'] = cs
-            return cs
+                cs = int(row[1])
+            row.close()
+        return cs
+    _detect_casing = engine_base.connection_memoize(
+        ('mysql', 'lower_case_table_names'))(_detect_casing)
 
-    def _detect_collations(self, connection, charset=None):
+    def _detect_collations(self, connection):
         """Pull the active COLLATIONS list from the server.
 
         Cached per-connection.
         """
 
-        try:
-            return connection.info['collations']
-        except KeyError:
-            collations = {}
-            if self.server_version_info(connection) < (4, 1, 0):
-                pass
-            else:
-                rs = connection.execute('SHOW COLLATION')
-                for row in _compat_fetchall(rs, charset):
-                    collations[row[0]] = row[1]
-            connection.info['collations'] = collations
-            return collations
+        collations = {}
+        if self.server_version_info(connection) < (4, 1, 0):
+            pass
+        else:
+            charset = self._detect_charset(connection)
+            rs = connection.execute('SHOW COLLATION')
+            for row in _compat_fetchall(rs, charset):
+                collations[row[0]] = row[1]
+        return collations
+    _detect_collations = engine_base.connection_memoize(
+        ('mysql', 'collations'))(_detect_collations)
 
     def use_ansiquotes(self, useansi):
         self._use_ansiquotes = useansi
index b23ec06e519a1471161bf45905b93dfc6d02b4cd..2faeab65144ac63d5bd6cf9c1b7583bd3ebb054a 100644 (file)
@@ -12,7 +12,6 @@ from sqlalchemy.engine import default, base
 from sqlalchemy.sql import compiler, visitors
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
 from sqlalchemy import types as sqltypes
-from sqlalchemy.pool import connection_cache_decorator
 
 
 class OracleNumeric(sqltypes.Numeric):
@@ -379,9 +378,10 @@ class OracleDialect(default.DefaultDialect):
         else:
             return name.encode(self.encoding)
 
-    def get_default_schema_name(self,connection):
+    def get_default_schema_name(self, connection):
         return connection.execute('SELECT USER FROM DUAL').scalar()
-    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+    get_default_schema_name = base.connection_memoize(
+        ('dialect', 'default_schema_name'))(get_default_schema_name)
 
     def table_names(self, connection, schema):
         # note that table_names() isnt loading DBLINKed or synonym'ed tables
index abae27eb1021a8742fa8a213ed89e257d7f34d7c..94ad7d2e45d36204b0c5e8e4bcee6b39c0e0592f 100644 (file)
@@ -26,7 +26,6 @@ from sqlalchemy.engine import base, default
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
-from sqlalchemy.pool import connection_cache_decorator
 
 
 class PGInet(sqltypes.TypeEngine):
@@ -369,7 +368,8 @@ class PGDialect(default.DefaultDialect):
 
     def get_default_schema_name(self, connection):
         return connection.scalar("select current_schema()", None)
-    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+    get_default_schema_name = base.connection_memoize(
+        ('dialect', 'default_schema_name'))(get_default_schema_name)
 
     def last_inserted_ids(self):
         if self.context.last_inserted_ids is None:
index 3d7e651983779740a1f4fff6726e9f1beb39a8d8..cd662ac92e833eb3b11a9d4cb5d5058240a9ee0f 100644 (file)
@@ -12,7 +12,7 @@ higher-level statement-construction, connection-management, execution
 and result contexts.
 """
 
-import StringIO, sys
+import inspect, StringIO, sys
 from sqlalchemy import exceptions, schema, util, types, logging
 from sqlalchemy.sql import expression
 
@@ -1864,3 +1864,27 @@ class DefaultRunner(schema.SchemaVisitor):
             return default.arg(self.context)
         else:
             return default.arg
+
+
+def connection_memoize(key):
+    """Decorator, memoize a function in a connection.info stash.
+
+    Only applicable to functions which take no arguments other than a
+    connection.  The memo will be stored in ``connection.info[key]``.
+
+    """
+    def decorate(fn):
+        spec = inspect.getargspec(fn)
+        assert len(spec[0]) == 2
+        assert spec[0][1] == 'connection'
+        assert spec[1:3] == (None, None)
+
+        def decorated(self, connection):
+            try:
+                return connection.info[key]
+            except KeyError:
+                connection.info[key] = val = fn(self, connection)
+                return val
+
+        return util.function_named(decorated, fn.__name__)
+    return decorate
index e22d1d8d37f912039fe64f561ff160ae6f9eea39..94d9127f0cd061fecd6c26f4e5d18b5a39178ba7 100644 (file)
@@ -58,21 +58,6 @@ def clear_managers():
         manager.close()
     proxies.clear()
 
-def connection_cache_decorator(func):
-    """apply caching to the return value of a function, using
-    the 'info' collection on its given connection."""
-
-    name = func.__name__
-
-    def do_with_cache(self, connection):
-        try:
-            return connection.info[name]
-        except KeyError:
-            value = func(self, connection)
-            connection.info[name] = value
-            return value
-    return do_with_cache
-    
 class Pool(object):
     """Base class for connection pools.
 
index 26b6dbe9a060bc3a56a57e9553ac4765e92fb969..101ef1462c9c0d792508e03ae3256746af00bdd1 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import inspect, itertools, sets, sys, warnings, weakref
+import inspect, itertools, new, sets, sys, warnings, weakref
 import __builtin__
 types = __import__('types')
 
@@ -1055,7 +1055,22 @@ class symbol(object):
             return sym
         finally:
             symbol._lock.release()
-            
+
+
+def function_named(fn, name):
+    """Return a function with a given __name__.
+
+    Will assign to __name__ and return the original function if possible on
+    the Python implementation, otherwise a new function will be constructed.
+
+    """
+    try:
+        fn.__name__ = name
+    except TypeError:
+        fn = new.function(fn.func_code, fn.func_globals, name,
+                          fn.func_defaults, fn.func_closure)
+    return fn
+
 def conditional_cache_decorator(func):
     """apply conditional caching to the return value of a function."""
 
@@ -1166,8 +1181,5 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
 
     func_with_warning.__doc__ = doc
     func_with_warning.__dict__.update(func.__dict__)
-    try:
-        func_with_warning.__name__ = func.__name__
-    except TypeError:
-        pass
-    return func_with_warning
+
+    return function_named(func_with_warning, func.__name__)
index e1bd47d29113e0eac1ceba19de0d3fa572627ef7..00478908ef8c8d46e174c27890e0ba88a49c82ff 100644 (file)
@@ -927,6 +927,49 @@ class SQLTest(TestBase, AssertsCompiledSQL):
             self.assert_compile(cast(t.c.col, type_), expected)
 
 
+class ExecutionTest(TestBase):
+    """Various MySQL execution special cases."""
+
+    __only_on__ = 'mysql'
+
+    def test_charset_caching(self):
+        engine = engines.testing_engine()
+
+        cx = engine.connect()
+        meta = MetaData()
+
+        assert ('mysql', 'charset') not in cx.info
+        assert ('mysql', 'force_charset') not in cx.info
+
+        cx.execute(text("SELECT 1")).fetchall()
+        assert ('mysql', 'charset') not in cx.info
+
+        meta.reflect(cx)
+        assert ('mysql', 'charset') in cx.info
+
+        cx.execute(text("SET @squiznart=123"))
+        assert ('mysql', 'charset') in cx.info
+
+        # the charset invalidation is very conservative
+        cx.execute(text("SET TIMESTAMP = DEFAULT"))
+        assert ('mysql', 'charset') not in cx.info
+
+        cx.info[('mysql', 'force_charset')] = 'latin1'
+
+        assert engine.dialect._detect_charset(cx) == 'latin1'
+        assert cx.info[('mysql', 'charset')] == 'latin1'
+
+        del cx.info[('mysql', 'force_charset')]
+        del cx.info[('mysql', 'charset')]
+
+        meta.reflect(cx)
+        assert ('mysql', 'charset') in cx.info
+
+        # String execution doesn't go through the detector.
+        cx.execute("SET TIMESTAMP = DEFAULT")
+        assert ('mysql', 'charset') in cx.info
+
+
 def colspec(c):
     return testing.db.dialect.schemagenerator(testing.db.dialect,
         testing.db, None, None).get_column_specification(c)