]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reverted previous "strings instead of tuples" change due to more specific test...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Mar 2008 23:30:31 +0000 (23:30 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Mar 2008 23:30:31 +0000 (23:30 +0000)
- changed cache decorator call on default_schema_name call to a connection.info specific one

lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/sql/compiler.py

index f09f08507e3505feae1127361acf043ec71e6357..18b236d1c8e1a4665fc8048a50ddef7150d656da 100644 (file)
@@ -157,6 +157,7 @@ import datetime, inspect, re, sys
 from array import array as _array
 
 from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy.pool import connection_cache_decorator
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
@@ -1542,13 +1543,9 @@ class MySQLDialect(default.DefaultDialect):
             return False
 
     def get_default_schema_name(self, connection):
-        try:
-            return connection.info['default_schema']
-        except KeyError:
-            connection.info['default_schema'] = schema = \
-              connection.execute('SELECT DATABASE()').scalar()
-            return schema
-
+        return connection.execute('SELECT DATABASE()').scalar()
+    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+    
     def table_names(self, connection, schema):
         """Return a Unicode SHOW TABLES from a given schema."""
 
index 8a2df7f3bd4b482eb7b61a082a6fe4aebaa4efec..2763972649e4405e3f157b5da7acce73d23cc940 100644 (file)
@@ -12,6 +12,7 @@ from sqlalchemy.engine import default, base
 from sqlalchemy.sql import compiler, visitors
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
 from sqlalchemy import types as sqltypes
+from sqlalchemy.pool import connection_cache_decorator
 
 
 class OracleNumeric(sqltypes.Numeric):
@@ -380,8 +381,8 @@ class OracleDialect(default.DefaultDialect):
 
     def get_default_schema_name(self,connection):
         return connection.execute('SELECT USER FROM DUAL').scalar()
-    get_default_schema_name = util.cache_decorator(get_default_schema_name)
-    
+    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+
     def table_names(self, connection, schema):
         # note that table_names() isnt loading DBLINKed or synonym'ed tables
         if schema is None:
index 9dbf359b6883b7885febea360b8cf2c780d10644..abae27eb1021a8742fa8a213ed89e257d7f34d7c 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.engine import base, default
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
+from sqlalchemy.pool import connection_cache_decorator
 
 
 class PGInet(sqltypes.TypeEngine):
@@ -368,8 +369,8 @@ class PGDialect(default.DefaultDialect):
 
     def get_default_schema_name(self, connection):
         return connection.scalar("select current_schema()", None)
-    get_default_schema_name = util.cache_decorator(get_default_schema_name)
-    
+    get_default_schema_name = connection_cache_decorator(get_default_schema_name)
+
     def last_inserted_ids(self):
         if self.context.last_inserted_ids is None:
             raise exceptions.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
index 94d9127f0cd061fecd6c26f4e5d18b5a39178ba7..e22d1d8d37f912039fe64f561ff160ae6f9eea39 100644 (file)
@@ -58,6 +58,21 @@ def clear_managers():
         manager.close()
     proxies.clear()
 
+def connection_cache_decorator(func):
+    """apply caching to the return value of a function, using
+    the 'info' collection on its given connection."""
+
+    name = func.__name__
+
+    def do_with_cache(self, connection):
+        try:
+            return connection.info[name]
+        except KeyError:
+            value = func(self, connection)
+            connection.info[name] = value
+            return value
+    return do_with_cache
+    
 class Pool(object):
     """Base class for connection pools.
 
index 71d12b6be878adc9b53f93ce9c1ddfc1840982e2..76e2ca2608eab5b9ebd388d67d098ac55624f8e9 100644 (file)
@@ -418,9 +418,8 @@ class DefaultCompiler(engine.Compiled):
         return bind_name
 
     def _truncated_identifier(self, ident_class, name):
-        key = ident_class + "|" + name
-        if key in self.generated_ids:
-            return self.generated_ids[key]
+        if (ident_class, name) in self.generated_ids:
+            return self.generated_ids[(ident_class, name)]
 
         anonname = ANONYMOUS_LABEL.sub(self._process_anon, name)
 
@@ -430,19 +429,18 @@ class DefaultCompiler(engine.Compiled):
             self.generated_ids[ident_class] = counter + 1
         else:
             truncname = anonname
-        self.generated_ids[key] = truncname
+        self.generated_ids[(ident_class, name)] = truncname
         return truncname
 
     def _process_anon(self, match):
         (ident, derived) = match.group(1,2)
-        key = 'anonymous|' + ident
+        key = ('anonymous', ident)
         if key in self.generated_ids:
             return self.generated_ids[key]
         else:
-            counter_key = "anon_counter|" + derived
-            anonymous_counter = self.generated_ids.get(counter_key, 1)
+            anonymous_counter = self.generated_ids.get(('anon_counter', derived), 1)
             newname = derived + "_" + str(anonymous_counter)
-            self.generated_ids[counter_key] = anonymous_counter + 1
+            self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1
             self.generated_ids[key] = newname
             return newname