]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- ResultProxy and friends always reference the DBAPI connection at the same time
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Dec 2010 05:46:11 +0000 (00:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Dec 2010 05:46:11 +0000 (00:46 -0500)
as the cursor.  There is no reason for CursorFairy - the only use case would be,
end-user is using the pool or pool.manage with DBAPI connections, uses a cursor,
deferences the owning connection and continues using cursor.  This is an almost
nonexistent use case and isn't correct usage at a DBAPI level.  Take out CursorFairy.
- move the "check for a dot in the colname" logic out to the sqlite dialect.

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/pool.py
test/sql/test_quote.py

index 261793a33c3e42e18f885c843bd313f878c6e0e6..a74ea0c3c93eb170ffcffd8ca79667393a42fa50 100644 (file)
@@ -56,8 +56,7 @@ import datetime, re, time
 
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql, exc, pool, DefaultClause
-from sqlalchemy.engine import default
-from sqlalchemy.engine import reflection
+from sqlalchemy.engine import default, base, reflection
 from sqlalchemy import types as sqltypes
 from sqlalchemy import util
 from sqlalchemy.sql import compiler, functions as sql_functions
@@ -335,6 +334,20 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
             result = self.quote_schema(index.table.schema, index.table.quote_schema) + "." + result
         return result
 
+class SQLiteExecutionContext(default.DefaultExecutionContext):
+    def get_result_proxy(self):
+        rp = base.ResultProxy(self)
+        if rp._metadata:
+            # adjust for dotted column names.  SQLite
+            # in the case of UNION may store col names as 
+            # "tablename.colname"
+            # in cursor.description
+            for colname in rp._metadata.keys:
+                if "." in colname:
+                    trunc_col = colname.split(".")[1]
+                    rp._metadata._set_keymap_synonym(trunc_col, colname)
+        return rp
+    
 class SQLiteDialect(default.DefaultDialect):
     name = 'sqlite'
     supports_alter = False
@@ -352,7 +365,8 @@ class SQLiteDialect(default.DefaultDialect):
     ischema_names = ischema_names
     colspecs = colspecs
     isolation_level = None
-
+    execution_ctx_cls = SQLiteExecutionContext
+    
     supports_cast = True
     supports_default_values = True
 
index 90d6bda86d69416c5664a3a005de0faac563d124..2672499712b03a019f53ae8116b7e26e60f86c05 100644 (file)
@@ -1319,6 +1319,23 @@ class Connection(Connectable):
             self.close()
         
         return r
+    
+    def _safe_close_cursor(self, cursor):
+        """Close the given cursor, catching exceptions
+        and turning into log warnings.
+        
+        """
+        try:
+            cursor.close()
+        except Exception, e:
+            try:
+                ex_text = str(e)
+            except TypeError:
+                ex_text = repr(e)
+            self.connection._logger.warn("Error closing cursor: %s", ex_text)
+
+            if isinstance(e, (SystemExit, KeyboardInterrupt)):
+                raise
         
     def _handle_dbapi_exception(self, 
                                     e, 
@@ -1347,7 +1364,7 @@ class Connection(Connectable):
                 self.engine.dispose()
             else:
                 if cursor:
-                    cursor.close()
+                    self._safe_close_cursor(cursor)
                 self._autorollback()
                 if self.should_close_with_result:
                     self.close()
@@ -2163,7 +2180,6 @@ class ResultMetaData(object):
         # saved attribute lookup self._processors)
         self._keymap = keymap = {}
         self.keys = []
-        self._echo = parent._echo
         context = parent.context
         dialect = context.dialect
         typemap = dialect.dbapi_type_map
@@ -2172,14 +2188,6 @@ class ResultMetaData(object):
             if dialect.description_encoding:
                 colname = colname.decode(dialect.description_encoding)
 
-            if '.' in colname:
-                # sqlite will in some circumstances prepend table name to
-                # colnames, so strip
-                origname = colname
-                colname = colname.split('.')[-1]
-            else:
-                origname = None
-
             if context.result_map:
                 try:
                     name, obj, type_ = context.result_map[colname.lower()]
@@ -2208,11 +2216,6 @@ class ResultMetaData(object):
                 # or the more precise ColumnElement)
                 keymap[name.lower()] = (processor, None)
 
-            # store the "origname" if we truncated (sqlite only)
-            if origname and \
-                    keymap.setdefault(origname.lower(), rec) is not rec:
-                keymap[origname.lower()] = (processor, None)
-            
             if dialect.requires_name_normalize:
                 colname = dialect.normalize_name(colname)
                 
@@ -2221,11 +2224,22 @@ class ResultMetaData(object):
                 for o in obj:
                     keymap[o] = rec
 
-        if self._echo:
-            self.logger = context.engine.logger
-            self.logger.debug(
+        if parent._echo:
+            context.engine.logger.debug(
                 "Col %r", tuple(x[0] for x in metadata))
-
+    
+    def _set_keymap_synonym(self, name, origname):
+        """Set a synonym for the given name.
+        
+        Some dialects (SQLite at the moment) may use this to 
+        adjust the column names that are significant within a
+        row.
+        
+        """
+        rec = (processor, i) = self._keymap[origname.lower()]
+        if self._keymap.setdefault(name, rec) is not rec:
+            self._keymap[name] = (processor, None)
+        
     def _key_fallback(self, key):
         map = self._keymap
         result = None
@@ -2413,7 +2427,7 @@ class ResultProxy(object):
 
         if not self.closed:
             self.closed = True
-            self.cursor.close()
+            self.connection._safe_close_cursor(self.cursor)
             if _autoclose_connection and \
                 self.connection.should_close_with_result:
                 self.connection.close()
index 587dbf92ff738d7d194e29ab01f131e996bd9a95..6c708aa52c26110860b8229d86c5b900d498db2c 100644 (file)
@@ -405,7 +405,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.parameters = self.__encode_param_keys(parameters)
             self.executemany = len(parameters) > 1
             
-            if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
+            if not dialect.supports_unicode_statements and isinstance(statement, unicode):
                 self.unicode_statement = statement
                 self.statement = statement.encode(self.dialect.encoding)
             else:
index 387ef830d13ac68d4055af89ecc7c9fc46ea5efc..02d56dead7b531f2d084515751a60c7d2e36df44 100644 (file)
@@ -419,8 +419,7 @@ class _ConnectionFairy(object):
 
     def cursor(self, *args, **kwargs):
         try:
-            c = self.connection.cursor(*args, **kwargs)
-            return _CursorFairy(self, c)
+            return self.connection.cursor(*args, **kwargs)
         except Exception, e:
             self.invalidate(e=e)
             raise
@@ -487,42 +486,6 @@ class _ConnectionFairy(object):
         self.connection = None
         self._connection_record = None
 
-class _CursorFairy(object):
-    __slots__ = '_parent', 'cursor', 'execute'
-
-    def __init__(self, parent, cursor):
-        self._parent = parent
-        self.cursor = cursor
-        self.execute = cursor.execute
-        
-    def invalidate(self, e=None):
-        self._parent.invalidate(e=e)
-    
-    def __iter__(self):
-        return iter(self.cursor)
-        
-    def close(self):
-        try:
-            self.cursor.close()
-        except Exception, e:
-            try:
-                ex_text = str(e)
-            except TypeError:
-                ex_text = repr(e)
-            self._parent._logger.warn("Error closing cursor: %s", ex_text)
-
-            if isinstance(e, (SystemExit, KeyboardInterrupt)):
-                raise
-    
-    def __setattr__(self, key, value):
-        if key in self.__slots__:
-            object.__setattr__(self, key, value)
-        else:
-            setattr(self.cursor, key, value)
-            
-    def __getattr__(self, key):
-        return getattr(self.cursor, key)
-
 class SingletonThreadPool(Pool):
     """A Pool that maintains one connection per thread.
 
index 8f27a7b3c15a05f16f4f11b22f748864bd84626d..e880388d7ee01ff77f8a78ddd0f52764d9e22d73 100644 (file)
@@ -140,7 +140,9 @@ class QuoteTest(TestBase, AssertsCompiledSQL):
         if labels arent quoted, a query in postgresql in particular will fail since it produces:
 
         SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC"
-        FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa
+        FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, 
+                "WorstCase1"."UPPERCASE" AS UPPERCASE, 
+                "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa
 
         where the "UPPERCASE" column of "LaLa" doesnt exist.
         """